├── tests
├── __init__.py
├── test_sanity.py
├── DataGenerators.py
├── test_random_crop.py
├── test_resample_augmentations.py
├── test_augment_zoom.py
├── test_spatial_transformations.py
├── test_axis_mirroring.py
├── test_multithreaded_augmenter.py
├── test_color_augmentations.py
└── test_DataLoader.py
├── batchgenerators
├── __init__.py
├── datasets
│ ├── __init__.py
│ └── cifar.py
├── examples
│ ├── __init__.py
│ ├── brats2017
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── readme.md
│ │ ├── brats2017_dataloader_2D.py
│ │ └── brats2017_preprocessing.py
│ ├── multithreaded_dataloading.py
│ └── cifar10.py
├── transforms
│ ├── __init__.py
│ ├── abstract_transforms.py
│ ├── resample_transforms.py
│ ├── sample_normalization_transforms.py
│ ├── channel_selection_transforms.py
│ ├── crop_and_pad_transforms.py
│ └── color_transforms.py
├── utilities
│ ├── __init__.py
│ ├── data_splitting.py
│ ├── custom_types.py
│ └── file_and_folder_operations.py
├── augmentations
│ ├── __init__.py
│ ├── normalizations.py
│ ├── resample_augmentations.py
│ ├── noise_augmentations.py
│ ├── color_augmentations.py
│ └── crop_and_pad_augmentations.py
└── dataloading
│ ├── __init__.py
│ ├── dataset.py
│ ├── single_threaded_augmenter.py
│ ├── nondet_multi_threaded_augmenter.py
│ └── data_loader.py
├── setup.cfg
├── .arcconfig
├── DKFZ_Logo.png
├── HIP_Logo.png
├── requirements.txt
├── .travis.yml
├── Makefile
├── setup.py
├── .gitignore
├── Readme.md
└── LICENSE
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/utilities/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/dataloading/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/batchgenerators/examples/brats2017/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description_file = README.md
--------------------------------------------------------------------------------
/.arcconfig:
--------------------------------------------------------------------------------
1 | {
2 | "phabricator.uri": "https://phabricator.mitk.org/"
3 | }
4 |
--------------------------------------------------------------------------------
/DKFZ_Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/batchgenerators/HEAD/DKFZ_Logo.png
--------------------------------------------------------------------------------
/HIP_Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIC-DKFZ/batchgenerators/HEAD/HIP_Logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pillow>=7.1.2
2 | threadpoolctl
3 | scikit-learn
4 | numpy>=1.10.2
5 | scipy
6 | scikit-image
7 | scikit-learn
8 | unittest2
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | matrix:
4 | include:
5 | - os: linux
6 | python: "3.7"
7 | dist: bionic
8 | # - os: windows
9 | # python: "3.6"
10 |
11 | install:
12 | - pip install .
13 | script:
14 | - pytest
15 |
--------------------------------------------------------------------------------
/batchgenerators/examples/brats2017/config.py:
--------------------------------------------------------------------------------
# Machine-specific settings for the BraTS2017/2018 example. The example's readme asks you to change these paths
# and the number of threads to match your system before running the preprocessing script.
brats_preprocessed_folder = "/media/fabian/DeepLearningData/BraTS2017_preprocessed"
brats_folder_with_downloaded_train_data = "/media/fabian/DeepLearningData/Brats17TrainingData"
num_threads_for_brats_example = 8
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | init:
2 | pip install numpy ; \
3 | pip install -r requirements.txt
4 |
5 | test:
6 | python -m unittest discover
7 |
8 | install_develop:
9 | python setup.py develop
10 |
11 | install:
12 | python setup.py install
13 |
14 | documentation:
15 | sphinx-apidoc -e -f DeepLearningBatchGeneratorUtils -o doc/
16 |
--------------------------------------------------------------------------------
/batchgenerators/dataloading/dataset.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
class Dataset(object, metaclass=ABCMeta):
    """
    Abstract base class for datasets that are indexed by position.

    Subclasses must implement ``__getitem__`` and ``__len__``. The metaclass is declared on the class itself
    (assigning ``__metaclass__`` inside ``__init__`` merely creates a local variable and has no effect), so the
    abstract methods are enforced: instantiating a class that does not override both raises ``TypeError``.
    """

    def __init__(self):
        # kept so that subclasses calling super().__init__() keep working
        pass

    @abstractmethod
    def __getitem__(self, item):
        '''
        needs to return a data_dict for the sample at the position item
        :param item: position of the requested sample
        :return: data_dict of that sample
        '''
        pass

    @abstractmethod
    def __len__(self):
        '''
        returns how many items the dataset has
        :return: number of items
        '''
        pass
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/tests/test_sanity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 |
18 |
class TestSanity(unittest.TestCase):
    """Smoke test that only verifies the test runner itself is operational."""

    def test_sanity(self):
        # A trivially true comparison: a failure here points at the test setup, not at the library.
        trivially_true = (1 == 1)
        self.assertTrue(trivially_true, "Sanity test failed")


if __name__ == '__main__':
    unittest.main()
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Packages that are installed together with batchgenerators.
_install_requires = [
    "pillow>=7.1.2",
    "numpy>=1.10.2",
    "scipy",
    "scikit-image",
    "scikit-learn",
    "future",
    "pandas",
    "unittest2",
    "threadpoolctl"
]

# Search keywords attached to the package metadata.
_keywords = ['data augmentation', 'deep learning', 'image segmentation', 'image classification',
             'medical image analysis', 'medical image segmentation']

# Package metadata; the test suite is excluded from the installed packages.
setup(
    name='batchgenerators',
    version='0.25.1',
    description='Data augmentation toolkit',
    url='https://github.com/MIC-DKFZ/batchgenerators',
    author='Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, '
           'Helmholtz Imaging Platform',
    author_email='f.isensee@dkfz-heidelberg.de',
    license='Apache License Version 2.0, January 2004',
    packages=find_packages(exclude=["tests"]),
    install_requires=_install_requires,
    keywords=_keywords,
)
26 |
--------------------------------------------------------------------------------
/batchgenerators/examples/brats2017/readme.md:
--------------------------------------------------------------------------------
1 | # BraTS2017/2018 example
2 |
3 |
4 | This folder contains a complete example of how to process BraTS2017/2018 data with batchgenerators. You need to
5 | adapt the scripts to match your system and download location. The adaptation should be straightforward. All you need to
6 | do is to change the paths and the number of threads in config.py, then execute `brats2017_preprocessing.py` for
7 | preprocessing the data.
8 |
9 | Once preprocessed, have a look at `brats2017_dataloader_2D.py` and `brats2017_dataloader_3D.py` for how to implement
10 | data loaders for 2D and 3D network training, respectively. Naturally, these files contain everything you need, including
11 | data augmentation and multiprocessing. They are not designed to be just executed because there is no network training
12 | in there. The idea is that you look at them and execute what code they have in a controlled manner so that you can
13 | get a feel for how batchgenerators work. Questions? -> f.isensee@dkfz.de
14 |
15 | Why are these not IPython Notebooks? I don't like IPython Notebooks. Simple =)
16 |
17 | **IMPORTANT** these DataLoaders are not suited for test set prediction! You need to iterate over preprocessed test
18 | data by yourself.
19 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | .idea
92 |
93 | .DS_Store
94 | .pytest_cache/*
--------------------------------------------------------------------------------
/batchgenerators/utilities/data_splitting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from sklearn.model_selection import KFold
17 | import numpy as np
18 |
19 |
def get_split_deterministic(all_keys, fold=0, num_splits=5, random_state=12345):
    """
    Splits a list of patient identifiers (or numbers) into num_splits folds and returns the split for fold fold.

    The keys are sorted before splitting, so for a fixed random_state the result does not depend on the order in
    which the keys are passed in.

    :param all_keys: iterable of identifiers to split
    :param fold: index of the fold to return, must satisfy 0 <= fold < num_splits
    :param num_splits: number of folds handed to KFold
    :param random_state: seed for the shuffling done by KFold
    :return: tuple (train_keys, test_keys) of numpy arrays
    :raises ValueError: if fold is not a valid fold index
    """
    # Without this check an out-of-range fold would run through the loop without a match and fail with an
    # UnboundLocalError at the return statement.
    if not 0 <= fold < num_splits:
        raise ValueError("fold must satisfy 0 <= fold < num_splits, got fold=%s, num_splits=%s"
                         % (fold, num_splits))

    all_keys_sorted = np.sort(list(all_keys))
    splits = KFold(n_splits=num_splits, shuffle=True, random_state=random_state)
    for i, (train_idx, test_idx) in enumerate(splits.split(all_keys_sorted)):
        if i == fold:
            # indexing with an index array already returns new arrays, no extra copy is needed
            return all_keys_sorted[train_idx], all_keys_sorted[test_idx]

    # defensive: only reachable if the splitter yields fewer than num_splits splits
    raise ValueError("fold %s was not produced by the splitter" % (fold,))
37 |
--------------------------------------------------------------------------------
/batchgenerators/dataloading/single_threaded_augmenter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
class SingleThreadedAugmenter(object):
    """
    Augmenter that runs entirely in the calling thread, intended for debugging custom transforms: because no
    background worker is involved you can step straight into your augmentations. Not meant for training; use
    MultiThreadedAugmenter if you want background process(es).

    Args:
        data_loader (generator or DataLoaderBase instance): source of batches. It is advanced with the builtin
        next() and each batch must be a dict that complies with our data structure

        transform (Transform instance): transformation applied to every batch (use Compose to chain several).
        May be None, in which case batches are passed through untouched
    """
    def __init__(self, data_loader, transform):
        self.data_loader = data_loader
        self.transform = transform

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.data_loader)
        if self.transform is None:
            return batch
        # the batch dict is unpacked into keyword arguments for the transform
        return self.transform(**batch)

    def next(self):
        return self.__next__()
43 |
--------------------------------------------------------------------------------
/batchgenerators/utilities/custom_types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Union, Tuple, Any, Callable
17 | import numpy as np
18 |
19 |
ScalarType = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]


def sample_scalar(scalar_type: ScalarType, *args):
    """
    Turns a scalar specification into a concrete value.

    :param scalar_type: one of
        - int or float: returned unchanged
        - list or tuple of length 2 (lower, upper): a value is drawn with np.random.uniform(lower, upper); if both
          entries are equal that entry is returned without sampling
        - callable: called with *args and its result is returned
    :param args: positional arguments forwarded to scalar_type if it is a callable
    :return: the resulting scalar
    :raises AssertionError: if a list/tuple does not have length 2 or its first entry is larger than the second
    :raises RuntimeError: if scalar_type is none of the supported kinds
    """
    if isinstance(scalar_type, (int, float)):
        return scalar_type
    elif isinstance(scalar_type, (list, tuple)):
        assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
        assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                                 'otherwise we cannot sample using np.random.uniform'
        if scalar_type[0] == scalar_type[1]:
            return scalar_type[0]
        return np.random.uniform(*scalar_type)
    elif callable(scalar_type):
        return scalar_type(*args)
    else:
        # The type has to be interpolated into the message here; passing it as a second argument to RuntimeError
        # would leave the '%s' placeholder unformatted.
        raise RuntimeError('Unknown type: %s. Expected: int, float, list, tuple, callable' % (type(scalar_type),))


if __name__ == '__main__':
    sample_scalar(0.5)
    sample_scalar((0, 1))
    sample_scalar(lambda: np.random.uniform(-1, 2))
    sample_scalar(lambda x, y: np.random.uniform(x, y), 0.5, 2)
44 |
--------------------------------------------------------------------------------
/tests/DataGenerators.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import numpy as np
17 | from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
18 |
19 |
class BasicDataLoader(SlimDataLoaderBase):
    """
    Loader for in-memory data. self._data is expected to be a tuple of
    images (b,c,x,y(,z)) and segmentations (b,c,x,y(,z))
    """

    def generate_train_batch(self):
        images = self._data[0]
        # draw batch_size sample indices at random, with replacement
        picked = np.random.choice(images.shape[0], self.batch_size, True, None)
        # np.array(...) hands out fresh arrays so that subsequent augmentation techniques cannot modify the
        # original dataset
        batch_images = np.array(images[picked])
        batch_segs = np.array(self._data[1][picked])
        return {"data": batch_images,
                "seg": batch_segs}
34 |
35 |
class DummyGenerator(SlimDataLoaderBase):
    """
    Builds synthetic data and seg arrays of shape dataset_size and serves random batches drawn from them.
    """
    def __init__(self, dataset_size, batch_size, fill_data='random', fill_seg='ones'):
        # only these fill modes are implemented for data
        if fill_data not in ("random", "ones"):
            raise NotImplementedError
        images = np.random.random(dataset_size) if fill_data == "random" else np.ones(dataset_size)

        # seg can only be filled with ones
        if fill_seg != "ones":
            raise NotImplementedError
        labels = np.ones(dataset_size)

        super(DummyGenerator, self).__init__((images, labels), batch_size, None)

    def generate_train_batch(self):
        chosen = np.random.choice(self._data[0].shape[0], self.batch_size)
        return {'data': self._data[0][chosen], 'seg': self._data[1][chosen]}
61 |
62 |
class OneDotDataLoader(SlimDataLoaderBase):
    def __init__(self, dataset_size, batch_size, coord_of_voxel):
        """
        creates both data and seg with only one voxel being = 1 and the rest zero. This will allow easy tracking of
        spatial transformations
        :param dataset_size: (b,c,x,y(,z))
        :param batch_size: number of samples per generated batch
        :param coord_of_voxel: (x, y(, z)))
        """
        super(OneDotDataLoader, self).__init__(None, batch_size, None)

        self.data = np.zeros(dataset_size)
        self.seg = np.zeros(dataset_size)
        # Select every batch entry and channel, then the single spatial position. Chaining `[:, :][coord]` would
        # apply the coordinate to the batch and channel axes instead of the spatial ones.
        voxel = (slice(None), slice(None)) + tuple(coord_of_voxel)
        self.data[voxel] = 1
        self.seg[voxel] = 1

    def generate_train_batch(self):
        idx = np.random.choice(self.data.shape[0], self.batch_size)

        data = self.data[idx]
        # take the segmentation from self.seg, not from self.data
        seg = self.seg[idx]
        return {'data': data, 'seg': seg}
84 |
--------------------------------------------------------------------------------
/tests/test_random_crop.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | import numpy as np
18 | from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop
19 |
20 |
class TestRandomCrop(unittest.TestCase):
    """Checks the output shapes of random_crop for 2D/3D inputs given as arrays or as lists of arrays."""

    def setUp(self):
        np.random.seed(1234)

    def _check_crop(self, d, s, expected_shape):
        # Compare the complete shape tuples. Comparing element-wise through zip() would silently stop at the
        # shorter tuple and therefore miss a wrong number of dimensions.
        self.assertEqual(expected_shape, tuple(d.shape), "data has unexpected return shape")
        self.assertEqual(expected_shape, tuple(s.shape), "seg has unexpected return shape")

        self.assertEqual(np.sum(s == 0), 0, "Zeros encountered in seg meaning that we did padding which should not have"
                                            " happened here!")

    def test_random_crop_3D(self):
        data = np.random.random((32, 4, 64, 56, 48))
        seg = np.ones(data.shape)

        d, s = random_crop(data, seg, 32, 0)

        self._check_crop(d, s, (32, 4, 32, 32, 32))

    def test_random_crop_2D(self):
        data = np.random.random((32, 4, 64, 56))
        seg = np.ones(data.shape)

        d, s = random_crop(data, seg, 32, 0)

        self._check_crop(d, s, (32, 4, 32, 32))

    def test_random_crop_3D_from_List(self):
        data = [np.random.random((4, 64+i, 56+i, 48+i)) for i in range(32)]
        seg = [np.random.random((4, 64+i, 56+i, 48+i)) for i in range(32)]

        d, s = random_crop(data, seg, 32, 0)

        # 3D input: three cropped spatial axes are expected
        self._check_crop(d, s, (32, 4, 32, 32, 32))

    def test_random_crop_2D_from_List(self):
        data = [np.random.random((4, 64+i, 56+i)) for i in range(32)]
        seg = [np.random.random((4, 64+i, 56+i)) for i in range(32)]

        d, s = random_crop(data, seg, 32, 0)

        self._check_crop(d, s, (32, 4, 32, 32))


if __name__ == '__main__':
    unittest.main()
77 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/abstract_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import abc
17 | from warnings import warn
18 |
19 | import numpy as np
20 |
21 |
class AbstractTransform(object, metaclass=abc.ABCMeta):
    """
    Base class of all transforms. A transform is called with the entries of the data dict as keyword arguments.

    The metaclass is passed in the class statement; a ``__metaclass__`` class attribute is ignored by Python 3 and
    would leave ``__call__`` unenforced. Instantiating a class that does not override ``__call__`` raises
    ``TypeError``.
    """

    @abc.abstractmethod
    def __call__(self, **data_dict):
        raise NotImplementedError("Abstract, so implement")

    def __repr__(self):
        # lists all instance attributes as "key = repr(value)"
        attributes = ", ".join(key + " = " + repr(val) for key, val in self.__dict__.items())
        return type(self).__name__ + "( " + attributes + " )"
33 |
34 |
class RndTransform(AbstractTransform):
    """Applies a transformation with a specified probability

    Deprecated: constructing an instance emits a DeprecationWarning.

    Args:
        transform: The transformation (or composed transformation)

        prob: The probability with which to apply it

        alternative_transform: Will be applied if transform is not called. If transform alters for example the
        spatial dimension of the data, you need to compensate that with calling a dummy transformation that alters the
        spatial dimension in a similar way. We included this functionality because of SpatialTransform which has the
        ability to do cropping. If we want to not apply the spatial transformation we will still need to crop and
        therefore set the alternative_transform to an instance of RandomCropTransform or CenterCropTransform
    """

    def __init__(self, transform, prob=0.5, alternative_transform=None):
        # stacklevel=2 attributes the warning to the code that constructs the RndTransform instead of to this
        # module, so that it is reported at the place that needs to be changed.
        warn("This is deprecated. All applicable transforms now have a p_per_sample argument which allows "
             "batchgenerators to do or not do an augmentation on a per-sample basis instead of the entire batch",
             DeprecationWarning, stacklevel=2)
        self.alternative_transform = alternative_transform
        self.transform = transform
        self.prob = prob

    def __call__(self, **data_dict):
        rnd_val = np.random.uniform()

        if rnd_val < self.prob:
            return self.transform(**data_dict)
        if self.alternative_transform is not None:
            return self.alternative_transform(**data_dict)
        # neither transform applies: hand the data back unchanged
        return data_dict
68 |
69 |
class Compose(AbstractTransform):
    """Chains several transforms; each one receives the output of the previous one.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, **data_dict):
        result = data_dict
        # apply the transforms in list order, feeding each result into the next transform
        for transform in self.transforms:
            result = transform(**result)
        return result

    def __repr__(self):
        return "{} ( {!r} )".format(type(self).__name__, self.transforms)
93 |
--------------------------------------------------------------------------------
/tests/test_resample_augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | import numpy as np
18 | from batchgenerators.augmentations.resample_augmentations import augment_linear_downsampling_scipy
19 |
20 |
class TestAugmentResample(unittest.TestCase):
    """Tests for augment_linear_downsampling_scipy on 2D and 3D data given as (c, x, y(, z)) arrays."""

    def setUp(self):
        np.random.seed(1234)
        self.data_3D = np.random.random((2, 64, 56, 48))
        self.data_2D = np.random.random((2, 64, 56))

        # Arrays in which every value is unique. The shape is passed positionally because the `newshape` keyword
        # of np.reshape is deprecated in recent NumPy releases.
        self.data_3D_unique = np.arange(2 * 64 * 56 * 48).reshape((2, 64, 56, 48))
        self.data_2D_unique = np.arange(2 * 64 * 56).reshape((2, 64, 56))

        self.d_3D = augment_linear_downsampling_scipy(np.copy(self.data_3D), zoom_range=[0.5, 1.5], per_channel=False)
        self.d_2D = augment_linear_downsampling_scipy(np.copy(self.data_2D), zoom_range=[0.5, 1.5], per_channel=False)

        # only channel 0 is selected for augmentation here
        self.d_3D_channel = augment_linear_downsampling_scipy(np.copy(self.data_3D), zoom_range=[0.5, 1.5],
                                                              per_channel=False, channels=[0])
        self.d_2D_channel = augment_linear_downsampling_scipy(np.copy(self.data_2D), zoom_range=[0.5, 1.5],
                                                              per_channel=False, channels=[0])

        self.zoom_factor = 0.5
        self.d_3D_upsample = augment_linear_downsampling_scipy(np.copy(self.data_3D_unique),
                                                               zoom_range=[self.zoom_factor, self.zoom_factor],
                                                               per_channel=False, order_downsample=0)
        self.d_2D_upsample = augment_linear_downsampling_scipy(np.copy(self.data_2D_unique),
                                                               zoom_range=[self.zoom_factor, self.zoom_factor],
                                                               per_channel=False, order_downsample=0)

    def test_augment_resample(self):
        self.assertTrue(self.data_3D.shape == self.d_3D.shape,
                        "shape of transformed data not the same as original one (3D)")
        self.assertTrue(self.data_2D.shape == self.d_2D.shape,
                        "shape of transformed data not the same as original one (2D)")

    def test_augment_resample_upsample(self):
        self.assertTrue(int(len(np.unique(self.data_3D_unique))*pow(self.zoom_factor, 3)) == len(np.unique(self.d_3D_upsample)),
                        "number of unique values after resampling is not correct")

    def test_augment_resample_channel(self):
        np.testing.assert_array_equal(self.d_3D_channel[1], self.data_3D[1],
                                      "channel that should not be augmented is changed (3D)")
        np.testing.assert_array_equal(self.d_2D_channel[1], self.data_2D[1],
                                      "channel that should not be augmented is changed (2D)")

        self.assertFalse(np.all(self.d_3D_channel[0] == self.data_3D[0]),
                         "channel that should be augmented is not changed (3D)")
        self.assertFalse(np.all(self.d_2D_channel[0] == self.data_2D[0]),
                         "channel that should be augmented is not changed (2D)")


if __name__ == '__main__':
    unittest.main()
71 |
--------------------------------------------------------------------------------
/batchgenerators/augmentations/normalizations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import numpy as np
17 |
18 |
def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8):
    """Rescale every sample of a batch linearly into the interval ``rnge``.

    Args:
        data (np.ndarray): batch indexed as ``data[b]`` (sample) and, when
            ``per_channel`` is set, ``data[b, c]`` (channel).
        rnge (tuple): ``(lower, upper)`` bounds of the target interval.
        per_channel (bool): if True, min and max are taken per channel of each
            sample; otherwise over the whole sample.
        eps (float): added to the value range to avoid a division by zero for
            constant input.

    Returns:
        np.ndarray: new array of the same shape. Floating (and complex) input
        keeps its dtype; any other dtype is returned as float64, because the
        normalized values are fractional and cannot be stored in e.g. an
        integer array.
    """
    # An integer output buffer would truncate the normalized values.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    data_normalized = np.zeros(data.shape, dtype=out_dtype)
    for b in range(data.shape[0]):
        if per_channel:
            for c in range(data.shape[1]):
                data_normalized[b, c] = min_max_normalization(data[b, c], eps)
        else:
            data_normalized[b] = min_max_normalization(data[b], eps)

    # Map [0, 1] onto [rnge[0], rnge[1]].
    data_normalized *= (rnge[1] - rnge[0])
    data_normalized += rnge[0]
    return data_normalized


def min_max_normalization(data, eps):
    """Shift and scale ``data`` so that its minimum maps to 0 and its maximum to ~1.

    Args:
        data (np.ndarray): values to normalize; not modified.
        eps (float): added to ``max - min`` so constant input does not divide by zero.

    Returns:
        np.ndarray: ``(data - min) / (max - min + eps)`` as a new array.
    """
    mn = data.min()
    mx = data.max()
    old_range = mx - mn + eps
    # Out-of-place division: an in-place ``/=`` cannot cast the float result
    # back into an integer array and raises for integer input.
    return (data - mn) / old_range
41 |
def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8):
    """Standardize every sample of a batch to zero mean and (approximately) unit variance.

    Args:
        data (np.ndarray): batch indexed as ``data[b]`` (sample) and, when
            ``per_channel`` is set, ``data[b, c]`` (channel).
        per_channel (bool): if True, mean and std are computed per channel of
            each sample; otherwise over the whole sample.
        epsilon (float): added to the standard deviation to avoid a division
            by zero for constant input.

    Returns:
        np.ndarray: new array of the same shape. Floating (and complex) input
        keeps its dtype; any other dtype is returned as float64 so the
        standardized values are not truncated.
    """
    # An integer output buffer would silently truncate the result.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    data_normalized = np.zeros(data.shape, dtype=out_dtype)
    for b in range(data.shape[0]):
        if per_channel:
            for c in range(data.shape[1]):
                mean = data[b, c].mean()
                std = data[b, c].std() + epsilon
                data_normalized[b, c] = (data[b, c] - mean) / std
        else:
            mean = data[b].mean()
            std = data[b].std() + epsilon
            data_normalized[b] = (data[b] - mean) / std
    return data_normalized
55 |
56 |
def mean_std_normalization(data, mean, std, per_channel=True):
    """Normalize a batch with externally supplied mean and standard deviation.

    Args:
        data: either a numpy array indexed as ``data[b][c]`` or a non-empty
            list/tuple of numpy arrays (one per sample; assumed to share the
            shape and dtype of the first entry).
        mean: scalar, or (with ``per_channel``) a sequence with one value per channel.
        std: scalar, or (with ``per_channel``) a sequence with one value per channel.
        per_channel (bool): if True, ``mean[c]`` / ``std[c]`` are applied to
            channel ``c``; scalars are broadcast to all channels. If False,
            ``mean`` and ``std`` are applied to each whole sample.

    Returns:
        np.ndarray: ``(data - mean) / std``. Floating (and complex) input keeps
        its dtype; any other dtype is returned as float64.

    Raises:
        TypeError: if ``data`` is neither a numpy array nor a list/tuple.
        AssertionError: if a per-channel ``mean`` or ``std`` sequence does not
            have one entry per channel, or a list input is empty / does not
            hold numpy arrays.
    """
    # Determine shape and dtype first: lists have no ``.shape``, so this has to
    # happen before the output buffer is allocated.
    if isinstance(data, np.ndarray):
        data_shape = tuple(data.shape)
        in_dtype = data.dtype
    elif isinstance(data, (list, tuple)):
        assert len(data) > 0 and isinstance(data[0], np.ndarray)
        data_shape = [len(data)] + list(data[0].shape)
        in_dtype = data[0].dtype
    else:
        raise TypeError("Data has to be either a numpy array or a list")

    # An integer output buffer would silently truncate the result.
    out_dtype = in_dtype if np.issubdtype(in_dtype, np.inexact) else np.float64
    data_normalized = np.zeros(data_shape, dtype=out_dtype)

    if per_channel:
        n_channels = data_shape[1]
        # mean and std are validated independently so that neither check can
        # be skipped because of the other argument's type.
        if np.ndim(mean) == 0:
            mean = [mean] * n_channels
        else:
            assert len(mean) == n_channels
        if np.ndim(std) == 0:
            std = [std] * n_channels
        else:
            assert len(std) == n_channels

    for b in range(data_shape[0]):
        if per_channel:
            for c in range(data_shape[1]):
                data_normalized[b][c] = (data[b][c] - mean[c]) / std[c]
        else:
            data_normalized[b] = (data[b] - mean) / std
    return data_normalized
82 |
83 |
def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False):
    """Clamp values outside the given percentiles to the percentile values, in place.

    Args:
        data: batch; modified in place and also returned.
        percentile_lower: percentile (as passed to np.percentile) below which values are raised.
        percentile_upper: percentile (as passed to np.percentile) above which values are lowered.
        per_channel (bool): if True the percentiles are computed for every
            ``data[b, c]`` separately, otherwise once per sample ``data[b]``.

    Returns:
        The same ``data`` object that was passed in.
    """
    for b in range(len(data)):
        if per_channel:
            regions = [data[b, c] for c in range(data.shape[1])]
        else:
            regions = [data[b]]

        for region in regions:
            lower_bound = np.percentile(region, percentile_lower)
            upper_bound = np.percentile(region, percentile_upper)
            # region is the same indexing result the original wrote through,
            # so masked assignment lands in `data`
            region[region < lower_bound] = lower_bound
            region[region > upper_bound] = upper_bound
    return data
98 |
--------------------------------------------------------------------------------
/tests/test_augment_zoom.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | import numpy as np
18 | from batchgenerators.augmentations.spatial_transformations import augment_zoom
19 |
20 |
class TestAugmentZoom(unittest.TestCase):
    """Checks augment_zoom on a binary block image for a zoom factor of 2 (3D and 2D).

    The fixtures are all-zero arrays with one axis-aligned block of ones; the
    segmentation is the very same array as the image. Nearest-neighbour
    interpolation (order 0) is requested so the zoomed output stays binary.
    """

    def setUp(self):
        np.random.seed(1234)
        self.data3D = np.zeros((2, 64, 56, 48))
        self.data3D[:, 21:41, 12:32, 13:33] = 1
        self.seg3D = self.data3D

        self.zoom_factors = 2
        self.d3D, self.s3D = augment_zoom(self.data3D, self.seg3D, zoom_factors=self.zoom_factors, order=0, order_seg=0)

        self.data2D = np.zeros((2, 64, 56))
        self.data2D[:, 21:41, 12:32] = 1
        self.seg2D = self.data2D
        self.d2D, self.s2D = augment_zoom(self.data2D, self.seg2D, zoom_factors=self.zoom_factors, order=0, order_seg=0)

    def test_augment_zoom_3D_dimensions(self):
        """Spatial axes are scaled by the zoom factor, the leading axis is unchanged."""
        np.testing.assert_array_equal(self.zoom_factors * np.array(self.data3D.shape[1:]), np.array(self.d3D.shape[1:]), "image has unexpected return shape")
        self.assertEqual(self.data3D.shape[0], self.d3D.shape[0], "color has unexpected return shape")
        self.assertEqual(self.seg3D.shape[0], self.s3D.shape[0], "seg color channel has unexpected return shape")
        np.testing.assert_array_equal(self.zoom_factors * np.array(self.seg3D.shape[1:]), np.array(self.s3D.shape[1:]), "seg has unexpected return shape")

    def test_augment_zoom_3D_values(self):
        """The block of ones grows by zoom_factors ** 3 and lands at the doubled coordinates."""
        volume_scale = self.zoom_factors ** 3
        # np.sum instead of the builtin sum: the zoomed arrays hold millions of
        # elements and the values are 0/1, so the result is exact either way.
        self.assertEqual(volume_scale * np.sum(self.data3D), np.sum(self.d3D), "image has unexpected values inside")
        self.assertEqual(volume_scale * np.sum(self.seg3D), np.sum(self.s3D), "segmentation has unexpected values inside")
        self.assertTrue(np.all(self.d3D[:, 42:82, 24:64, 26:66].flatten()), "image data is not zoomed correctly")
        # Select every position whose value differs from 1 and make sure they
        # are not all truthy.
        idx = np.where(1 - self.d3D)
        tmp = self.d3D[idx]
        self.assertFalse(np.all(tmp), "image has unexpected values outside")
        idx = np.where(1 - self.s3D)
        tmp = self.s3D[idx]
        self.assertFalse(np.all(tmp), "segmentation has unexpected values outside")

    def test_augment_zoom_2D_dimensions(self):
        """Spatial axes are scaled by the zoom factor, the leading axis is unchanged."""
        np.testing.assert_array_equal(self.zoom_factors * np.array(self.data2D.shape[1:]), np.array(self.d2D.shape[1:]), "image has unexpected return shape")
        self.assertEqual(self.data2D.shape[0], self.d2D.shape[0], "color has unexpected return shape")
        self.assertEqual(self.seg2D.shape[0], self.s2D.shape[0], "seg color channel has unexpected return shape")
        np.testing.assert_array_equal(self.zoom_factors * np.array(self.seg2D.shape[1:]), np.array(self.s2D.shape[1:]), "seg has unexpected return shape")

    def test_augment_zoom_2D_values(self):
        """The block of ones grows by zoom_factors ** 2 and lands at the doubled coordinates."""
        area_scale = self.zoom_factors ** 2
        self.assertEqual(area_scale * np.sum(self.data2D), np.sum(self.d2D), "image has unexpected values inside")
        self.assertEqual(area_scale * np.sum(self.seg2D), np.sum(self.s2D), "segmentation has unexpected values inside")
        self.assertTrue(np.all(self.d2D[:, 42:82, 24:64].flatten()), "image data is not zoomed correctly")
        idx = np.where(1 - self.d2D)
        tmp = self.d2D[idx]
        self.assertFalse(np.all(tmp), "image has unexpected values outside")
        idx = np.where(1 - self.s2D)
        tmp = self.s2D[idx]
        self.assertFalse(np.all(tmp), "segmentation has unexpected values outside")
70 |
71 |
# Allow running this test module directly as a script (in addition to a test runner).
if __name__ == '__main__':
    unittest.main()
74 |
--------------------------------------------------------------------------------
/batchgenerators/augmentations/resample_augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from builtins import range
17 | import numpy as np
18 | import random
19 | from skimage.transform import resize
20 | from batchgenerators.augmentations.utils import uniform
21 |
22 |
def _draw_target_shape(shp, zoom_range, ignore_axes):
    """Sample a zoom factor from zoom_range and return the resulting low-resolution shape."""
    if isinstance(zoom_range[0], (tuple, list, np.ndarray)):
        # one (low, high) interval per spatial axis
        assert len(zoom_range) == len(shp)
        zoom = np.array([uniform(i[0], i[1]) for i in zoom_range])
    else:
        zoom = uniform(zoom_range[0], zoom_range[1])

    # Never let an axis collapse to zero elements; resampling needs >= 1.
    target_shape = np.maximum(np.round(shp * zoom).astype(int), 1)

    if ignore_axes is not None:
        for i in ignore_axes:
            target_shape[i] = shp[i]
    return target_shape


def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_channel=True, p_per_channel=1,
                                      channels=None, order_downsample=1, order_upsample=0, ignore_axes=None):
    '''
    Simulates low resolution: resamples channels of a sample to a randomly scaled shape and back to the
    original shape. With the default orders this is a linear downsampling followed by a nearest neighbor
    upsampling.

    Info:
    * Uses skimage.transform.resize for resampling (mode='edge', anti_aliasing=False).
    * Only the spatial axes (data_sample.shape[1:]) are resampled; the channel axis is iterated over.
    * data_sample is modified in place and returned.

    Args:
        data_sample: array with the channel axis first, i.e. data_sample[c] is one channel

        zoom_range: can be a scalar, a tuple/list/np.ndarray or a tuple of tuples. If scalar, that fixed zoom
        factor is used. If tuple/list/np.ndarray, then the zoom factor will be sampled from
        zoom_range[0], zoom_range[1] (zoom < 1 = downsampling!). If tuple of tuple then each inner tuple will
        give a sampling interval for each spatial axis (allows for different range of zoom values for each axis)

        p_per_channel: probability for downsampling/upsampling a channel

        per_channel (bool): whether to draw a new zoom_factor for each channel or keep one for all channels

        channels (list, tuple): if None then all channels can be augmented. If list then only the channel indices
        can be augmented (but may not always be depending on p_per_channel)

        order_downsample: interpolation order passed to resize for the downsampling step

        order_upsample: interpolation order passed to resize for the upsampling step

        ignore_axes: tuple/list of spatial axis indices (relative to data_sample.shape[1:]) that keep their
        original size in the low-resolution intermediate

    Returns:
        data_sample (the same object that was passed in)
    '''
    if not isinstance(zoom_range, (list, tuple, np.ndarray)):
        # A scalar means "always use this factor". It must become a (low, high)
        # pair because both ends are indexed below.
        zoom_range = [zoom_range, zoom_range]

    shp = np.array(data_sample.shape[1:])

    if not per_channel:
        # one zoom factor shared by all channels
        target_shape = _draw_target_shape(shp, zoom_range, ignore_axes)

    if channels is None:
        channels = list(range(data_sample.shape[0]))

    for c in channels:
        if np.random.uniform() < p_per_channel:
            if per_channel:
                target_shape = _draw_target_shape(shp, zoom_range, ignore_axes)

            downsampled = resize(data_sample[c].astype(float), target_shape, order=order_downsample, mode='edge',
                                 anti_aliasing=False)
            data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge',
                                    anti_aliasing=False)

    return data_sample
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/batchgenerators/examples/multithreaded_dataloading.py:
--------------------------------------------------------------------------------
1 | from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
2 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
3 | import numpy as np
4 |
5 |
class DummyDL(SlimDataLoaderBase):
    """Toy data loader that hands out the integers 0..99, strided across workers.

    Each instance starts at its own ``thread_id`` and advances by
    ``number_of_threads_in_multithreaded`` (both are read from ``self``;
    presumably provided by the base class / the augmenter - verify there), so
    several instances together visit every item exactly once.
    """

    def __init__(self, num_threads_in_mt=8):
        super().__init__(None, None, num_threads_in_mt)
        self._data = list(range(100))
        self.current_position = 0
        self.was_initialized = False

    def reset(self):
        """Rewind to this worker's first item."""
        self.current_position = self.thread_id
        self.was_initialized = True

    def generate_train_batch(self):
        """Return the next item of this worker, or reset and raise StopIteration at the end."""
        if not self.was_initialized:
            self.reset()

        position = self.current_position
        if not position < len(self._data):
            # exhausted: rewind first so the loader can be iterated again
            self.reset()
            raise StopIteration

        self.current_position = position + self.number_of_threads_in_multithreaded
        return self._data[position]
27 |
28 |
class DummyDLWithShuffle(DummyDL):
    """DummyDL variant that visits the items in a reshuffled order after every reset.

    The shuffle is seeded with the number of resets performed so far, so
    instances that have been reset equally often produce the same permutation.
    """

    def __init__(self, num_threads_in_mt=8):
        super().__init__(num_threads_in_mt)
        self.num_restarted = 0
        self.data_order = np.arange(len(self._data))

    def reset(self):
        """Rewind and reshuffle ``data_order`` with a seed derived from the restart counter."""
        super().reset()
        shuffler = np.random.RandomState(self.num_restarted)
        shuffler.shuffle(self.data_order)
        self.num_restarted += 1

    def generate_train_batch(self):
        """Return the next item according to the shuffled order, or reset and raise StopIteration."""
        if not self.was_initialized:
            self.reset()

        position = self.current_position
        if not position < len(self._data):
            self.reset()
            raise StopIteration

        self.current_position = position + self.number_of_threads_in_multithreaded
        return self._data[self.data_order[position]]
51 |
52 |
if __name__ == "__main__":
    # Why is it so hard to iterate only once over my entire training dataset when MultiThreadedAugmenter is used?
    # This is because MultiThreadedAugmenter will spawn num_threads workers and each worker will hold a copy of the
    # entire pipeline, including the DataLoader. Therefore, if your DataLoader is configured to run over the training
    # data once, but you have 8 threads, then what you will be getting from the MultiThreadedAugmenter is an iteration
    # over eight times your training dataset.

    # HELP I want to iterate over all my training data once per epoch.
    # Say no more. We got your back. Here is a simple example of how you can do that.
    #
    # We create a dummy dataloader that has the numbers 0 to 99 in its _data variable. In the MultiThreadedAugmenter,
    # each DataLoader will know what thread ID it has. We use that information to iterate over the training data.
    # Since there are 3 threads, each individual dataloader must return every third item (and start in a different
    # position).

    dl = DummyDL(num_threads_in_mt=3)
    mt = MultiThreadedAugmenter(dl, None, 3, 1, None)

    for i in mt:
        print(i)

    # You can run the mt as often as you want because the DataLoader will reset itself before raising StopIteration.
    for i in mt:
        print(i)

    for i in mt:
        print(i)

    # But wait. Isn't it suboptimal to iterate over training data always in the same order? Correct. Try this:

    dl = DummyDLWithShuffle(num_threads_in_mt=3)
    mt = MultiThreadedAugmenter(dl, None, 3, 1, None)

    batches = []
    for i in mt:
        batches.append(i)
    print(batches)
    assert len(np.unique(batches)) == 100 and len(batches) == 100  # assert makes sure we got what we wanted

    # Once again you can run that as often as you want.

    batches = []
    for i in mt:
        batches.append(i)
    print(batches)
    assert len(np.unique(batches)) == 100 and len(batches) == 100

    batches = []
    for i in mt:
        batches.append(i)
    print(batches)
    assert len(np.unique(batches)) == 100 and len(batches) == 100
115 |
--------------------------------------------------------------------------------
/batchgenerators/augmentations/noise_augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import random
17 | from typing import Tuple
18 |
19 | import numpy as np
20 | from batchgenerators.augmentations.utils import get_range_val, mask_random_squares
21 | from builtins import range
22 | from scipy.ndimage import gaussian_filter
23 |
24 |
def augment_rician_noise(data_sample, noise_variance=(0, 0.1)):
    """Add Rician-style noise to ``data_sample`` and return the result as a new array.

    A value is drawn uniformly from ``noise_variance`` and passed as the scale
    of two independent zero-mean normal fields. The output is the magnitude of
    (data + first field, second field), multiplied by the sign of the input.

    Args:
        data_sample (np.ndarray): input values; not modified.
        noise_variance (tuple): ``(low, high)`` interval for the noise scale.

    Returns:
        np.ndarray: noisy data with the shape of ``data_sample``.
    """
    scale = random.uniform(noise_variance[0], noise_variance[1])
    shape = data_sample.shape
    # Draw order is kept (field added to the data first) so seeded runs are unchanged.
    shifted = data_sample + np.random.normal(0.0, scale, size=shape)
    orthogonal = np.random.normal(0.0, scale, size=shape)
    magnitude = np.sqrt(shifted ** 2 + orthogonal ** 2)
    return magnitude * np.sign(data_sample)
31 |
32 |
def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1),
                           p_per_channel: float = 1, per_channel: bool = False) -> np.ndarray:
    """Add zero-mean Gaussian noise to randomly selected channels of ``data_sample``.

    Args:
        data_sample: array with the channel axis first; modified in place and returned.
        noise_variance: ``(low, high)`` interval the noise scale (passed as the
            scale argument of ``np.random.normal``) is drawn from. If both ends
            are equal that value is used without drawing.
        p_per_channel: probability with which each channel receives noise.
        per_channel: if True a new scale is drawn for every selected channel,
            otherwise a single scale is drawn up front and shared.

    Returns:
        The same ``data_sample`` array.
    """
    lower, upper = noise_variance[0], noise_variance[1]

    def draw_scale():
        # skip the random draw entirely for a degenerate interval
        return lower if lower == upper else random.uniform(lower, upper)

    shared_scale = None if per_channel else draw_scale()

    for c in range(data_sample.shape[0]):
        if not (np.random.uniform() < p_per_channel):
            continue
        scale = shared_scale if shared_scale is not None else draw_scale()
        # bug fixed: https://github.com/MIC-DKFZ/batchgenerators/issues/86
        noise = np.random.normal(0.0, scale, size=data_sample[c].shape)
        data_sample[c] = data_sample[c] + noise
    return data_sample
49 |
50 |
def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, float], per_channel: bool = True,
                          p_per_channel: float = 1, different_sigma_per_axis: bool = False,
                          p_isotropic: float = 0) -> np.ndarray:
    """Blur randomly selected channels of ``data_sample`` with a Gaussian filter.

    Args:
        data_sample: array with the channel axis first; modified in place and returned.
        sigma_range: passed to ``get_range_val`` to obtain a sigma.
        per_channel: if True a new sigma is drawn for every selected channel,
            otherwise one sigma is drawn up front and shared by all channels.
        p_per_channel: probability with which each channel is blurred.
        different_sigma_per_axis: if True, one sigma per spatial axis
            (``data_sample.shape[1:]``) is drawn instead of a single value.
        p_isotropic: only used with ``different_sigma_per_axis``; probability of
            falling back to a single sigma for all axes.

    Returns:
        The same ``data_sample`` array.
    """

    def draw_sigma():
        # A single sigma is used unless per-axis sigmas were requested and the
        # isotropic coin flip (only drawn in that case) fails.
        if (not different_sigma_per_axis) or (np.random.uniform() < p_isotropic):
            return get_range_val(sigma_range)
        return [get_range_val(sigma_range) for _ in data_sample.shape[1:]]

    sigma = None if per_channel else draw_sigma()

    for c in range(data_sample.shape[0]):
        if np.random.uniform() <= p_per_channel:
            if per_channel:
                sigma = draw_sigma()
            data_sample[c] = gaussian_filter(data_sample[c], sigma, order=0)
    return data_sample
72 |
73 |
def augment_blank_square_noise(data_sample, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False,
                               square_pos=None):
    """Mask random squares of ``data_sample`` via ``mask_random_squares``.

    ``square_size`` and ``n_squares`` are resolved through ``get_range_val``
    (in that order) before being handed on; ``noise_val``,
    ``channel_wise_n_val`` and ``square_pos`` are forwarded unchanged.

    Returns:
        Whatever ``mask_random_squares`` returns for the sample.
    """
    size_here = get_range_val(square_size)
    count_here = get_range_val(n_squares)

    return mask_random_squares(data_sample, square_size=size_here, n_squares=count_here,
                               n_val=noise_val, channel_wise_n_val=channel_wise_n_val,
                               square_pos=square_pos)
84 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/resample_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from warnings import warn
17 |
18 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
19 | from batchgenerators.augmentations.resample_augmentations import augment_linear_downsampling_scipy
20 | import numpy as np
21 |
22 |
class SimulateLowResolutionTransform(AbstractTransform):
    """Simulates low resolution by resampling samples to a randomly scaled shape and back again.

    Every sample of ``data_dict[data_key]`` is, with probability ``p_per_sample``, passed through
    ``augment_linear_downsampling_scipy`` and the result is written back into the batch.

    Args:
        zoom_range: can be either tuple/list/np.ndarray or tuple of tuple. If tuple/list/np.ndarray, then the zoom
        factor will be sampled from zoom_range[0], zoom_range[1] (zoom < 1 = downsampling!). If tuple of tuple then
        each inner tuple will give a sampling interval for each axis (allows for different range of zoom values for
        each axis)

        per_channel (bool): whether to draw a new zoom_factor for each channel or keep one for all channels

        p_per_channel: probability for resampling a channel of a selected sample

        channels (list, tuple): if None then all channels can be augmented. If list then only the channel indices can
        be augmented (but may not always be depending on p_per_channel)

        order_downsample: interpolation order of the downsampling step

        order_upsample: interpolation order of the upsampling step

        data_key: key of the batch entry that is augmented

        p_per_sample: probability for augmenting a sample

        ignore_axes: forwarded to augment_linear_downsampling_scipy
    """

    def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1,
                 channels=None, order_downsample=1, order_upsample=0, data_key="data", p_per_sample=1,
                 ignore_axes=None):
        self.zoom_range = zoom_range
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel
        self.channels = channels
        self.order_downsample = order_downsample
        self.order_upsample = order_upsample
        self.data_key = data_key
        self.p_per_sample = p_per_sample
        self.ignore_axes = ignore_axes

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        for b in range(len(batch)):
            if np.random.uniform() < self.p_per_sample:
                batch[b] = augment_linear_downsampling_scipy(batch[b],
                                                             zoom_range=self.zoom_range,
                                                             per_channel=self.per_channel,
                                                             p_per_channel=self.p_per_channel,
                                                             channels=self.channels,
                                                             order_downsample=self.order_downsample,
                                                             order_upsample=self.order_upsample,
                                                             ignore_axes=self.ignore_axes)
        return data_dict
75 |
76 |
class ResampleTransform(SimulateLowResolutionTransform):
    """Deprecated alias of SimulateLowResolutionTransform.

    Accepts the same arguments and forwards them unchanged; constructing it
    emits a DeprecationWarning.
    """

    def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1,
                 channels=None, order_downsample=1, order_upsample=0, data_key="data", p_per_sample=1,
                 ignore_axes=None):
        # stacklevel=2 attributes the warning to the code that instantiates the
        # deprecated class instead of to this file.
        warn("This class is deprecated. It was renamed to SimulateLowResolutionTransform. Please change your code",
             DeprecationWarning, stacklevel=2)
        super(ResampleTransform, self).__init__(zoom_range, per_channel, p_per_channel,
                                                channels, order_downsample, order_upsample, data_key, p_per_sample,
                                                ignore_axes)
84 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/sample_normalization_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from batchgenerators.augmentations.normalizations import cut_off_outliers, mean_std_normalization, range_normalization, \
17 | zero_mean_unit_variance_normalization
18 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
19 |
20 |
class RangeTransform(AbstractTransform):
    '''Rescales the data entry of a batch into the specified range (see range_normalization)

    Args:
        rnge (tuple of float): The range to which the data is scaled

        per_channel (bool): determines whether the min and max values used for the rescaling are computed over the whole
        sample or separately for each channel

        data_key: key of the batch entry that is rescaled

        label_key: stored on the instance; not used by __call__

    '''

    def __init__(self, rnge=(0, 1), per_channel=True, data_key="data", label_key="seg"):
        self.rnge = rnge
        self.per_channel = per_channel
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = range_normalization(data_dict[key], self.rnge, per_channel=self.per_channel)
        return data_dict
42 |
43 |
class CutOffOutliersTransform(AbstractTransform):
    """ Clamps outliers of the data entry of a batch to percentile values (see cut_off_outliers)

    Args:
        percentile_lower (float between 0 and 100): Lower cutoff percentile

        percentile_upper (float between 0 and 100): Upper cutoff percentile

        per_channel (bool): determines whether percentiles are computed for each color channel separately

        data_key: key of the batch entry that is processed

        label_key: stored on the instance; not used by __call__
    """

    def __init__(self, percentile_lower=0.2, percentile_upper=99.8, per_channel=False, data_key="data",
                 label_key="seg"):
        self.percentile_lower = percentile_lower
        self.percentile_upper = percentile_upper
        self.per_channel = per_channel
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = cut_off_outliers(data_dict[key], self.percentile_lower, self.percentile_upper,
                                          per_channel=self.per_channel)
        return data_dict
68 |
69 |
class ZeroMeanUnitVarianceTransform(AbstractTransform):
    """ Standardizes the data entry of a batch (see zero_mean_unit_variance_normalization)

    Args:
        per_channel (bool): determines whether mean and std are computed for and applied to each color channel
        separately

        epsilon (float): prevent nan if std is zero, keep at 1e-7

        data_key: key of the batch entry that is normalized

        label_key: stored on the instance; not used by __call__
    """

    def __init__(self, per_channel=True, epsilon=1e-7, data_key="data", label_key="seg"):
        self.per_channel = per_channel
        self.epsilon = epsilon
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = zero_mean_unit_variance_normalization(data_dict[key], self.per_channel, self.epsilon)
        return data_dict
90 |
91 |
class MeanStdNormalizationTransform(AbstractTransform):
    """ Normalizes the data entry of a batch with a given mean and std (see mean_std_normalization)

    Args:
        mean: mean that is subtracted from the data

        std: standard deviation the data is divided by

        per_channel (bool): forwarded to mean_std_normalization; determines whether mean and std are applied to
        each color channel separately

        data_key: key of the batch entry that is normalized

        label_key: stored on the instance; not used by __call__
    """

    def __init__(self, mean, std, per_channel=True, data_key="data", label_key="seg"):
        self.mean = mean
        self.std = std
        self.per_channel = per_channel
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = mean_std_normalization(data_dict[key], self.mean, self.std, self.per_channel)
        return data_dict
113 |
--------------------------------------------------------------------------------
/batchgenerators/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tarfile
4 | from urllib.request import urlretrieve
5 |
6 | import numpy as np
7 | from batchgenerators.dataloading.data_loader import DataLoader
8 | from batchgenerators.dataloading.dataset import Dataset
9 | from batchgenerators.utilities.file_and_folder_operations import join
10 |
11 |
def unpickle(file):
    '''
    Load a pickled object from disk, decoding Python 2 strings as bytes.

    taken from http://www.cs.toronto.edu/~kriz/cifar.html
    :param file: path of the pickle file
    :return: the unpickled object
    '''
    import pickle

    # NOTE: pickle can execute arbitrary code while loading; only use on trusted files.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')
23 |
24 |
def maybe_download_and_prepare_cifar(target_dir, cifar=10):
    '''
    Checks if cifar is already present in target_dir and downloads it if not.
    CIFAR comes in 5 batches that need to be unpickled. What a mess.
    We stack all 5 batches together to one single npy array. No idea why they are being so complicated

    The prepared data is stored as cifar<N>_training_data.npz and cifar<N>_test_data.npz (keys: data, labels,
    filenames). If both files already exist nothing is done. The downloaded archive and the extracted folder are
    removed afterwards.

    NOTE(review): the layout read below (folder 'cifar-<N>-batches-py' with data_batch_1..5 and test_batch, key
    b'labels') is applied for every value of cifar - confirm it matches archives other than CIFAR-10.

    :param target_dir: folder for the download and the prepared npz files (created if missing)
    :param cifar: number that is inserted into the download URL and the file names
    :return: None
    '''
    training_file = os.path.join(target_dir, 'cifar%d_training_data.npz' % cifar)
    test_file = os.path.join(target_dir, 'cifar%d_test_data.npz' % cifar)
    if os.path.isfile(test_file) and os.path.isfile(training_file):
        return

    # urlretrieve does not create missing folders
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    archive = os.path.join(target_dir, 'cifar-%d-python.tar.gz' % cifar)
    batches_dir = os.path.join(target_dir, 'cifar-%d-batches-py' % cifar)

    print('downloading CIFAR%d...' % cifar)
    urlretrieve('http://www.cs.toronto.edu/~kriz/cifar-%d-python.tar.gz' % cifar, archive)

    # NOTE(security): the archive is fetched over plain HTTP and extracted (and
    # later unpickled) without verification - only acceptable for a trusted source.
    # The with-block closes the archive even if extraction fails.
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(path=target_dir)

    data = []
    labels = []
    filenames = []

    for batch in range(1, 6):
        loaded = unpickle(os.path.join(batches_dir, 'data_batch_%d' % batch))
        data.append(loaded[b'data'].reshape((loaded[b'data'].shape[0], 3, 32, 32)).astype(np.uint8))
        labels += [int(i) for i in loaded[b'labels']]
        # NOTE(review): training filenames go through str() while the test
        # filenames below are stored as loaded - confirm which form is intended.
        filenames += [str(i) for i in loaded[b'filenames']]

    data = np.vstack(data)
    labels = np.array(labels)
    filenames = np.array(filenames)

    np.savez_compressed(training_file, data=data, labels=labels, filenames=filenames)

    test = unpickle(os.path.join(batches_dir, 'test_batch'))
    data = test[b'data'].reshape((test[b'data'].shape[0], 3, 32, 32)).astype(np.uint8)
    labels = [int(i) for i in test[b'labels']]
    filenames = [i for i in test[b'filenames']]

    np.savez_compressed(test_file, data=data, labels=labels, filenames=filenames)

    # clean up
    shutil.rmtree(batches_dir)
    os.remove(archive)
70 |
71 |
class CifarDataset(Dataset):
    """Dataset over the prepared CIFAR npz files, downloading and preparing them if necessary.

    Args:
        dataset_directory: folder that holds (or will hold) the prepared npz files
        train (bool): load the training split if True, otherwise the test split
        transform: optional callable applied as ``transform(**data_dict)`` to every item
        cifar (int): CIFAR variant; selects the download and the npz file names
    """

    def __init__(self, dataset_directory, train=True, transform=None, cifar=10):
        super(CifarDataset, self).__init__()
        self.transform = transform
        # Forward the variant: otherwise CIFAR-10 is always prepared while the
        # file for the requested variant is loaded below.
        maybe_download_and_prepare_cifar(dataset_directory, cifar=cifar)

        self.train = train

        # load appropriate data
        if train:
            fname = os.path.join(dataset_directory, 'cifar%d_training_data.npz' % cifar)
        else:
            fname = os.path.join(dataset_directory, 'cifar%d_test_data.npz' % cifar)

        # Indexing an npz archive reads the array into memory, so the file can
        # be closed as soon as the three arrays are extracted.
        with np.load(fname) as dataset:
            self.data = dataset['data']
            self.labels = dataset['labels']
            self.filenames = dataset['filenames']

    def __getitem__(self, item):
        """Return one item as a dict; the image keeps a leading axis of length 1 and is cast to float32."""
        data_dict = {'data': self.data[item:item+1].astype(np.float32), 'labels': self.labels[item], 'filenames': self.filenames[item]}
        if self.transform is not None:
            data_dict = self.transform(**data_dict)
        return data_dict

    def __len__(self):
        return len(self.data)
100 |
101 |
class HighPerformanceCIFARLoader(DataLoader):
    """DataLoader that slices whole batches out of in-memory CIFAR arrays.

    ``data`` is expected to be indexable as ``data[0]`` (images), ``data[1]`` (labels) and
    ``data[2]`` (filenames); every array is indexed with the batch indices in one go.
    """

    def __init__(self, data, batch_size, num_threads_in_multithreaded, seed_for_shuffle=1, infinite=False,
                 return_incomplete=False):
        super(HighPerformanceCIFARLoader, self).__init__(
            data, batch_size, num_threads_in_multithreaded, seed_for_shuffle,
            infinite=infinite, return_incomplete=return_incomplete
        )
        # one index per sample of the image array
        self.indices = np.arange(len(data[0]))

    def generate_train_batch(self):
        """Build one batch dict with keys ``data`` (float32), ``labels`` and ``filenames``."""
        selected = self.get_indices()
        images, targets, names = (self._data[position][selected] for position in range(3))
        return {'data': images.astype(np.float32), 'labels': targets, 'filenames': names}
118 |
119 |
--------------------------------------------------------------------------------
/batchgenerators/utilities/file_and_folder_operations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import pickle
18 | import json
19 | from typing import List, Union, Optional
20 |
21 |
def subdirs(folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True) -> List[str]:
    """
    List the directories located directly inside a folder (non-recursive, via os.scandir).

    Parameters:
    - folder: Folder whose immediate subdirectories are listed.
    - join: If True return the scandir entry paths, otherwise the bare directory names.
    - prefix: If given, keep only directories whose name starts with it.
    - suffix: If given, keep only directories whose name ends with it.
    - sort: If True the result is sorted alphabetically.

    Returns:
    - The matching subdirectory paths (or names).
    """
    def _name_matches(name: str) -> bool:
        # both filters are optional; a missing filter always matches
        if prefix is not None and not name.startswith(prefix):
            return False
        return suffix is None or name.endswith(suffix)

    with os.scandir(folder) as scanner:
        found = [
            item.path if join else item.name
            for item in scanner
            if item.is_dir() and _name_matches(item.name)
        ]

    return sorted(found) if sort else found
50 |
51 |
def subfiles(folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True) -> List[str]:
    """
    List the regular files located directly inside a folder (non-recursive, via os.scandir).

    Parameters:
    - folder: Folder whose files are listed.
    - join: If True return the scandir entry paths, otherwise the bare file names.
    - prefix: If given, keep only files whose name starts with it.
    - suffix: If given, keep only files whose name ends with it.
    - sort: If True the result is sorted alphabetically.

    Returns:
    - The matching file paths (or names).
    """
    matches = []
    with os.scandir(folder) as scanner:
        for item in scanner:
            if not item.is_file():
                continue
            if prefix is not None and not item.name.startswith(prefix):
                continue
            if suffix is not None and not item.name.endswith(suffix):
                continue
            matches.append(item.path if join else item.name)

    if sort:
        matches.sort()
    return matches
81 |
82 |
def nifti_files(folder: str, join: bool = True, sort: bool = True) -> List[str]:
    """Return the files in ``folder`` whose name ends with '.nii.gz' (delegates to subfiles)."""
    nifti_suffix = '.nii.gz'
    return subfiles(folder, join=join, suffix=nifti_suffix, sort=sort)
85 |
86 |
def maybe_mkdir_p(directory: str) -> None:
    """Create ``directory`` including missing parents; an already existing directory is not an error."""
    os.makedirs(name=directory, exist_ok=True)
89 |
90 |
def load_pickle(file: str, mode: str = 'rb'):
    """Open ``file`` with ``mode`` and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code - only use on trusted files.
    """
    with open(file, mode) as handle:
        return pickle.load(handle)
95 |
96 |
def write_pickle(obj, file: str, mode: str = 'wb') -> None:
    """Pickle ``obj`` into ``file``, which is opened with ``mode``."""
    with open(file, mode) as handle:
        pickle.dump(obj, file=handle)
100 |
101 |
def load_json(file: str):
    """Parse the JSON document stored in ``file`` (opened in text mode 'r') and return it."""
    with open(file, 'r') as handle:
        return json.load(handle)
106 |
107 |
def save_json(obj, file: str, indent: int = 4, sort_keys: bool = True) -> None:
    """Serialise ``obj`` as JSON into ``file`` (opened in text mode 'w').

    ``indent`` and ``sort_keys`` are forwarded unchanged to ``json.dump``.
    """
    with open(file, 'w') as handle:
        json.dump(obj, handle, indent=indent, sort_keys=sort_keys)
111 |
112 |
def pardir(path: str):
    """Return ``path`` with the parent-directory marker (os.pardir) appended; the result is not normalised."""
    parent_marker = os.pardir
    return os.path.join(path, parent_marker)
115 |
116 |
def split_path(path: str) -> List[str]:
    """
    Break ``path`` into its components at every os.sep (os.path.split, in contrast, only
    splits off the last component).
    """
    separator = os.sep
    return path.split(separator)
122 |
123 |
# Short aliases for frequently used os / os.path helpers
isfile = os.path.isfile
isdir = os.path.isdir
join = os.path.join
listdir = os.listdir
os_split_path = os.path.split
# NOTE: this is maybe_mkdir_p defined above (tolerates existing directories), not os.makedirs
makedirs = maybe_mkdir_p

# Alternative spellings of helpers defined above, so that either name can be used
write_json = save_json
save_pickle = write_pickle
subfolders = subdirs
--------------------------------------------------------------------------------
/batchgenerators/examples/cifar10.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import numpy as np
4 | import os
5 | from batchgenerators.dataloading.data_loader import DataLoaderFromDataset
6 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
7 | from batchgenerators.datasets.cifar import HighPerformanceCIFARLoader, CifarDataset
8 | from batchgenerators.transforms.abstract_transforms import Compose
9 | from batchgenerators.transforms.spatial_transforms import SpatialTransform
10 | from batchgenerators.transforms.utility_transforms import NumpyToTensor
11 | from torch._six import int_classes, string_classes, container_abcs
12 | from torch.utils.data.dataloader import numpy_type_map
13 |
# Read by default_collate below: when truthy, tensors are stacked directly into a
# shared-memory storage instead of a freshly allocated output tensor.
_use_shared_memory = False
15 |
16 |
def default_collate(batch):
    r"""Collate a list of samples into one batch whose outermost dimension is the batch size.

    The handling is chosen from the type of the first element of ``batch``:
    tensors and numpy arrays are stacked, numpy scalars / ints / floats become tensors,
    strings are returned unchanged, mappings and sequences are collated recursively.
    Raises TypeError for anything else (including string/object numpy arrays).
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    first = batch[0]
    first_type = type(first)

    if isinstance(first, torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            total_elements = sum([t.numel() for t in batch])
            shared_storage = first.storage()._new_shared(total_elements)
            out = first.new(shared_storage)
        return torch.stack(batch, 0, out=out)

    if first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        if first_type.__name__ == 'ndarray':
            # arrays of string classes and objects cannot be converted to tensors
            if re.search('[SaUO]', first.dtype.str) is not None:
                raise TypeError(error_msg.format(first.dtype))
            return torch.stack([torch.from_numpy(array) for array in batch], 0)
        if first.shape == ():  # scalars
            to_python = float if first.dtype.name.startswith('float') else int
            return numpy_type_map[first.dtype.name](list(map(to_python, batch)))
        # numpy object that is neither an ndarray nor a scalar
        raise TypeError(error_msg.format(first_type))

    if isinstance(first, int_classes):
        return torch.LongTensor(batch)
    if isinstance(first, float):
        return torch.DoubleTensor(batch)
    if isinstance(first, string_classes):
        return batch
    if isinstance(first, container_abcs.Mapping):
        return {key: default_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, container_abcs.Sequence):
        return [default_collate(field) for field in zip(*batch)]

    raise TypeError(error_msg.format(first_type))
56 |
57 |
# Benchmark script: times three ways of iterating over the CIFAR10 training data
# (1) batchgenerators with a Dataset + default collate, (2) batchgenerators with the
# array-slicing HighPerformanceCIFARLoader, (3) the PyTorch DataLoader.
if __name__ == '__main__':
    ### the current implementation of the batchgenerators parts of this script does not use _use_shared_memory!

    from time import time
    batch_size = 50
    num_workers = 8
    pin_memory = False
    num_epochs = 3
    # machine-specific location of the prepared cifar10 npz files; adapt before running
    dataset_dir = '/media/fabian/data/data/cifar10'
    numpy_to_tensor = NumpyToTensor(['data', 'labels'], cast_to=None)
    fname = os.path.join(dataset_dir, 'cifar10_training_data.npz')
    dataset = np.load(fname)
    cifar_dataset_as_arrays = (dataset['data'], dataset['labels'], dataset['filenames'])
    print('batch_size', batch_size)
    print('num_workers', num_workers)
    print('pin_memory', pin_memory)
    print('num_epochs', num_epochs)

    tr_transforms = [SpatialTransform((32, 32))] * 1  # SpatialTransform is computationally expensive and we need some
    # load on the CPU; the multiplier (currently 1) controls how many copies are stacked on top of each other
    tr_transforms.append(numpy_to_tensor)
    tr_transforms = Compose(tr_transforms)

    cifar_dataset = CifarDataset(dataset_dir, train=True, transform=tr_transforms)

    # variant 1: the transforms run inside the dataset, so the augmenter gets no transform (None)
    dl = DataLoaderFromDataset(cifar_dataset, batch_size, num_workers, 1)
    mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, pin_memory)

    # untimed pass: count the batches and check that 'data' is 4-dimensional
    batches = 0
    for _ in mt:
        batches += 1
        assert len(_['data'].shape) == 4

    assert batches == len(cifar_dataset) / batch_size  # this assertion only holds if len(dataset) is divisible by
    # batch size

    start = time()
    for _ in range(num_epochs):
        batches = 0
        for _ in mt:
            batches += 1
    stop = time()
    print('batchgenerators took %03.4f seconds' % (stop - start))

    # The best I can do:

    dl = HighPerformanceCIFARLoader(cifar_dataset_as_arrays, batch_size, num_workers, 1)  # this circumvents the
    # default_collate function, just to see if that is slowing things down
    mt = MultiThreadedAugmenter(dl, tr_transforms, num_workers, 1, None, pin_memory)

    # untimed pass: count the batches and check that 'data' is 4-dimensional
    batches = 0
    for _ in mt:
        batches += 1
        assert len(_['data'].shape) == 4

    assert batches == len(cifar_dataset_as_arrays[0]) / batch_size  # this assertion only holds if len(dataset) is
    # divisible by batch size

    start = time()
    for _ in range(num_epochs):
        batches = 0
        for _ in mt:
            batches += 1
    stop = time()
    print('high performance batchgenerators %03.4f seconds' % (stop - start))


    # variant 3: the PyTorch DataLoader, using the default_collate defined in this file
    from torch.utils.data import DataLoader as TorchDataLoader

    trainloader = TorchDataLoader(cifar_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  pin_memory=pin_memory, collate_fn=default_collate)

    # untimed pass: count the batches and check that 'data' is 4-dimensional
    batches = 0
    for _ in iter(trainloader):
        batches += 1
        assert len(_['data'].shape) == 4

    start = time()
    for _ in range(num_epochs):
        batches = 0
        for _ in trainloader:
            batches += 1
    stop = time()
    print('pytorch took %03.4f seconds' % (stop - start))
142 |
--------------------------------------------------------------------------------
/tests/test_spatial_transformations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | import numpy as np
18 |
19 | from batchgenerators.augmentations.spatial_transformations import augment_rot90, augment_resize, augment_transpose_axes
20 |
21 |
class AugmentTransposeAxes(unittest.TestCase):
    """Statistical check of augment_transpose_axes on a seeded random volume."""

    def setUp(self):
        np.random.seed(123)
        volume_shape = (2, 4, 5, 6)
        self.data_3D = np.random.random(volume_shape)
        self.seg_3D = np.random.random(self.data_3D.shape)

    def test_transpose_axes(self):
        """Over many calls, roughly half of the outputs must equal the input with axes 1 and 2 swapped."""
        n_iter = 1000
        swapped_count = 0
        for _ in range(n_iter):
            data_out, seg_out = augment_transpose_axes(self.data_3D, self.seg_3D, axes=(1, 0))

            swapped_count += int(np.array_equal(data_out, np.swapaxes(self.data_3D, 1, 2)))
        self.assertAlmostEqual(swapped_count, n_iter/2., delta=10)
38 |
39 |
class AugmentResize(unittest.TestCase):
    """Checks that augment_resize yields the requested shape and roughly keeps the data mean."""

    def setUp(self):
        np.random.seed(123)
        volume_shape = (2, 12, 14, 31)
        self.data_3D = np.random.random(volume_shape)
        self.seg_3D = np.random.random(self.data_3D.shape)

    def test_resize(self):
        """A scalar target size applies to every axis after the first one."""
        target = 15
        data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=target)

        original_mean = float(np.mean(self.data_3D))
        resized_mean = float(np.mean(data_resized))
        self.assertAlmostEqual(original_mean, resized_mean, places=2)

        axes_ok = all(
            (data_resized.shape[axis] == 15 and seg_resized.shape[axis] == 15)
            for axis in range(1, len(data_resized.shape))
        )
        self.assertTrue(axes_ok)

    def test_resize2(self):
        """A tuple target size gives one size per axis after the first one."""
        target = (7, 5, 6)
        data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=target)

        original_mean = float(np.mean(self.data_3D))
        resized_mean = float(np.mean(data_resized))
        self.assertAlmostEqual(original_mean, resized_mean, places=2)

        for resized in (data_resized, seg_resized):
            self.assertTrue(all([got == want for got, want in zip(resized.shape[1:], (7, 5, 6))]))
68 |
69 |
class AugmentRot90(unittest.TestCase):
    """Tests for augment_rot90: deterministic rotations and the randomness of axes / rotation count."""

    def setUp(self):
        np.random.seed(123)
        self.data_3D = np.random.random((2, 4, 5, 6))
        self.seg_3D = np.random.random(self.data_3D.shape)
        self.num_rot = [1]

    def test_rotation_checkerboard(self):
        """Rotating a 2x2 checkerboard by 1 or 4 quarter turns must hit both states about equally often."""
        data_2d_checkerboard = np.zeros((1, 2, 2))
        data_2d_checkerboard[0, 0, 0] = 1
        data_2d_checkerboard[0, 1, 1] = 1

        data_rotated_list = []
        n_iter = 1000
        for i in range(n_iter):
            d_r, _ = augment_rot90(np.copy(data_2d_checkerboard), None, num_rot=[4,1], axes=[0, 1])
            data_rotated_list.append(d_r)

        data_rotated_np = np.array(data_rotated_list)
        sum_data_list = np.sum(data_rotated_np, axis=0)
        a = np.unique(sum_data_list)
        self.assertAlmostEqual(a[0], n_iter/2, delta=20)
        # assertEqual reports both values on failure, unlike assertTrue(len(a) == 2)
        self.assertEqual(len(a), 2)

    def test_rotation(self):
        """One quarter turn in the first two spatial axes, checked slice by slice for data and seg."""
        data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=self.num_rot,
                                                  axes=[0, 1])

        for i in range(self.data_3D.shape[1]):
            self.assertTrue(np.array_equal(self.data_3D[:, i, :, :], np.flip(data_rotated[:, :, i, :], axis=1)))
            self.assertTrue(np.array_equal(self.seg_3D[:, i, :, :], np.flip(seg_rotated[:, :, i, :], axis=1)))

    def test_randomness_rotation_axis(self):
        """With three candidate axes, about one third of the calls must rotate in the first two."""
        n_iter = 100
        tmp = 0
        for j in range(n_iter):
            data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=self.num_rot,
                                                      axes=[0, 1, 2])
            if np.array_equal(self.data_3D[:, 0, :, :], np.flip(data_rotated[:, :, 0, :], axis=1)):
                tmp += 1
        # A count is an integer, so `places=2` demanded exactly 33 hits; use a tolerance like the
        # other statistical tests in this module (15 is roughly three standard deviations).
        self.assertAlmostEqual(tmp, n_iter / 3., delta=15)

    def test_rotation_list(self):
        """With num_rot in {1, 3} every slice must match either the normal or the inverse rotation."""
        num_rot = [1, 3]
        data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=num_rot,
                                                  axes=[0, 1])
        for i in range(self.data_3D.shape[1]):
            # check for normal and inverse rotations
            normal_rotated = np.array_equal(self.data_3D[:, i, :, :], data_rotated[:, :, -i-1, :])
            inverse_rotated = np.array_equal(self.data_3D[:, i, :, :], np.flip(data_rotated[:, :, i, :], axis=1))
            self.assertTrue(normal_rotated or inverse_rotated)
            self.assertTrue(np.array_equal(self.seg_3D[:, i, :, :], seg_rotated[:, :, -i - 1, :]) or
                            np.array_equal(self.seg_3D[:, i, :, :], np.flip(seg_rotated[:, :, i, :], axis=1)))

    def test_randomness_rotation_number(self):
        """With num_rot in {1, 3}, about half of the calls must produce the 'normal' rotation."""
        tmp = 0
        num_rot = [1, 3]
        n_iter = 1000
        for j in range(n_iter):
            data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=num_rot,
                                                      axes=[0, 1])
            normal_rotated = np.array_equal(self.data_3D[:, 0, :, :], data_rotated[:, :, - 1, :])
            if normal_rotated:
                tmp += 1
        self.assertAlmostEqual(tmp, n_iter / 2., delta=20)
138 |
139 |
# Run all test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
142 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/channel_selection_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import numpy as np
17 | from warnings import warn
18 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
19 |
20 |
class DataChannelSelectionTransform(AbstractTransform):
    """Keeps only the requested channels (axis 1) of the data entry and drops all others.

    Args:
        channels (list of int): Channels to keep.
        data_key (str): Key of the entry in the data dict that is indexed.

    """

    def __init__(self, channels, data_key="data"):
        self.channels = channels
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        batch = data_dict[key]
        data_dict[key] = batch[:, self.channels]
        return data_dict
36 |
37 |
class SegChannelSelectionTransform(AbstractTransform):
    """Segmentations may have more than one channel. This transform selects segmentation channels

    Args:
        channels (list of int): List of channels to be kept.
        keep_discarded_seg (bool): If True, the channels that are not kept are stored under
            the key 'discarded_seg'.
        label_key (str): Key of the segmentation entry in the data dict.

    """

    def __init__(self, channels, keep_discarded_seg=False, label_key="seg"):
        self.label_key = label_key
        self.channels = channels
        self.keep_discarded = keep_discarded_seg

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            # name the key that was actually looked up (it is configurable via label_key)
            message = ("You used SegChannelSelectionTransform but there is no '{}' key in your data_dict, "
                       "returning data_dict unmodified").format(self.label_key)
            warn(message, Warning)
        else:
            if self.keep_discarded:
                # every channel index of the first sample that is not selected
                discarded_seg_idx = [i for i in range(len(seg[0])) if i not in self.channels]
                data_dict['discarded_seg'] = seg[:, discarded_seg_idx]
            data_dict[self.label_key] = seg[:, self.channels]
        return data_dict
63 |
64 |
class SegChannelMergeTransform(AbstractTransform):
    """Merge selected channels of a onehot segmentation. Will merge into lowest index.

    Wherever one of the merged channels is non-zero, the lowest selected channel is set to
    ``fill_value``; the other selected channels are then removed. The merge writes into the
    segmentation array that was passed in.

    Args:
        channels (list of int): List of channels to be merged.
        keep_discarded_seg (bool): If True, the removed channels are stored under 'discarded_seg'.
        label_key (str): Key of the segmentation entry in the data dict.
        fill_value: Value written into the target channel where a merged channel is non-zero.

    """

    def __init__(self, channels, keep_discarded_seg=False, label_key="seg", fill_value=1):
        self.label_key = label_key
        self.channels = sorted(channels)
        self.keep_discarded = keep_discarded_seg
        self.fill_value = fill_value

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            # report this transform and the key that was actually looked up
            message = ("You used SegChannelMergeTransform but there is no '{}' key in your data_dict, "
                       "returning data_dict unmodified").format(self.label_key)
            warn(message, Warning)
        else:
            if self.keep_discarded:
                # list indexing copies, so this keeps the channels as they were before the merge
                data_dict['discarded_seg'] = seg[:, self.channels[1:]]
            all_channels = list(range(seg.shape[1]))
            for i in self.channels[1:]:
                # seg[:, target] is a view, so the masked assignment modifies seg itself
                seg[:, self.channels[0]][seg[:, i] != 0] = self.fill_value
                all_channels.remove(i)
            data_dict[self.label_key] = seg[:, all_channels]
        return data_dict
93 |
94 |
class SegChannelRandomSwapTransform(AbstractTransform):
    """Randomly swap two segmentation channels.

    The swap is applied along axis 1 of the segmentation and writes into the array that
    was passed in.

    Args:
        axis1 (int): Index of the first channel for the swap
        axis2 (int): Index of the second channel for the swap
        swap_probability (float): Probability for swap
        label_key (str): Key of the segmentation entry in the data dict.

    """

    def __init__(self, axis1, axis2, swap_probability=0.5, label_key="seg"):
        self.axis1 = axis1
        self.axis2 = axis2
        self.swap_probability = swap_probability
        self.label_key = label_key

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            # report this transform and the key that was actually looked up
            message = ("You used SegChannelRandomSwapTransform but there is no '{}' key in your data_dict, "
                       "returning data_dict unmodified").format(self.label_key)
            warn(message, Warning)
        else:
            random_number = np.random.rand()
            if random_number < self.swap_probability:
                # the right-hand side is a copy (list indexing), so the swap does not alias
                seg[:, [self.axis1, self.axis2]] = seg[:, [self.axis2, self.axis1]]
            data_dict[self.label_key] = seg
        return data_dict
123 |
124 |
class SegChannelRandomDuplicateTransform(AbstractTransform):
    """Creates an additional seg channel full of zeros and randomly swaps it with the base channel.

    The zero channel is appended as the last channel (axis 1); the output therefore always has
    one channel more than the input.

    Args:
        axis (int): Index of the base channel that may be swapped with the new zero channel
        swap_probability (float): Probability for swap
        label_key (str): Key of the segmentation entry in the data dict.

    """

    def __init__(self, axis, swap_probability=0.5, label_key="seg"):
        self.axis = axis
        self.swap_probability = swap_probability
        self.label_key = label_key

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            # report this transform and the key that was actually looked up
            message = ("You used SegChannelRandomDuplicateTransform but there is no '{}' key in your data_dict, "
                       "returning data_dict unmodified").format(self.label_key)
            warn(message, Warning)
        else:
            # append one all-zero channel of matching dtype
            seg_shape = list(seg.shape)
            seg_shape[1] = 1
            seg = np.concatenate([seg, np.zeros(seg_shape, dtype=seg.dtype)], 1)
            random_number = np.random.rand()
            if random_number < self.swap_probability:
                seg[:, [self.axis, -1]] = seg[:, [-1, self.axis]]
            data_dict[self.label_key] = seg
        return data_dict
154 |
155 |
class SegLabelSelectionBinarizeTransform(AbstractTransform):
    """Will create a binary segmentation, with the selected labels in the foreground.

    Selected labels become 1, everything else becomes 0. The result is written into the
    segmentation array that was passed in (its dtype is kept).

    Args:
        label (int, list of int): Foreground label(s)
        label_key (str): Key of the segmentation entry in the data dict.

    """

    def __init__(self, label, label_key="seg"):
        self.label_key = label_key
        # accept numpy integer scalars as a single label as well
        if isinstance(label, (int, np.integer)):
            self.label = [label]
        else:
            self.label = sorted(label)

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            message = ("You used SegLabelSelectionBinarizeTransform but there is no '{}' key in your data_dict, "
                       "returning data_dict unmodified").format(self.label_key)
            warn(message, Warning)
        else:
            # Compute the foreground mask from the untouched labels in one step. Relabelling
            # sequentially would turn already-discarded voxels (set to 0) into foreground
            # whenever 0 itself is one of the selected labels.
            foreground = np.isin(seg, self.label)
            seg[...] = foreground
            data_dict[self.label_key] = seg
        return data_dict
185 |
--------------------------------------------------------------------------------
/batchgenerators/augmentations/color_augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from builtins import range
17 | from typing import Tuple, Union, Callable
18 |
19 | import numpy as np
20 | from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter
21 |
22 |
def _sample_contrast_factor(contrast_range):
    """Draw one contrast factor: call ``contrast_range`` if callable, otherwise sample from the (low, high) pair."""
    if callable(contrast_range):
        return contrast_range()
    lower, upper = contrast_range[0], contrast_range[1]
    # The coin is always flipped, so the number of random draws does not depend on the range.
    pick_lower_half = np.random.random() < 0.5
    if upper < 1:
        # The whole range lies below 1: sample it directly so the factor cannot exceed `upper`.
        return np.random.uniform(lower, upper)
    if pick_lower_half and lower < 1:
        return np.random.uniform(lower, 1)
    return np.random.uniform(max(lower, 1), upper)


def _contrast_adjusted(channel, factor, preserve_range):
    """Return ``channel`` scaled by ``factor`` around its mean, optionally clipped to its original min/max."""
    mean = channel.mean()
    adjusted = (channel - mean) * factor + mean
    if preserve_range:
        adjusted = np.clip(adjusted, channel.min(), channel.max())
    return adjusted


def augment_contrast(data_sample: np.ndarray,
                     contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25),
                     preserve_range: bool = True,
                     per_channel: bool = True,
                     p_per_channel: float = 1) -> np.ndarray:
    """
    Scale the contrast of each channel around the channel mean. Modifies data_sample in place.

    data_sample is iterated over its first axis (the channels).
    :param data_sample: array to augment; it is written to and returned
    :param contrast_range: (low, high) pair to sample the factor from, or a callable returning the factor.
    If the pair spans 1, factors below and above 1 are each chosen with probability 0.5
    :param preserve_range: clip each augmented channel to the min/max it had before
    :param per_channel: draw a separate factor for every channel instead of one shared factor
    :param p_per_channel: probability with which each channel is augmented
    :return: data_sample
    """
    # with a shared factor, it is drawn once before any per-channel decision is made
    shared_factor = None if per_channel else _sample_contrast_factor(contrast_range)

    for c in range(data_sample.shape[0]):
        if np.random.uniform() < p_per_channel:
            factor = _sample_contrast_factor(contrast_range) if per_channel else shared_factor
            data_sample[c] = _contrast_adjusted(data_sample[c], factor, preserve_range)
    return data_sample
71 |
72 |
def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channel: bool = True, p_per_channel: float = 1.):
    """
    Add a normally distributed offset to the channels of data_sample (in place).

    data_sample must have shape (c, x, y(, z)))
    :param data_sample: array to augment; it is written to and returned
    :param mu: mean of the normal distribution the offset is drawn from
    :param sigma: standard deviation of that distribution
    :param per_channel: draw a new offset for every augmented channel instead of one shared offset
    :param p_per_channel: probability with which each channel receives the offset
    :return: data_sample
    """
    # a shared offset is drawn once, before any per-channel decision
    shared_offset = None
    if not per_channel:
        shared_offset = np.random.normal(mu, sigma)

    num_channels = data_sample.shape[0]
    for channel in range(num_channels):
        if np.random.uniform() <= p_per_channel:
            offset = np.random.normal(mu, sigma) if per_channel else shared_offset
            data_sample[channel] += offset
    return data_sample
94 |
95 |
def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True):
    """
    Multiply data_sample (in place) with a factor drawn uniformly from multiplier_range.

    :param data_sample: array to augment; it is written to and returned
    :param multiplier_range: (low, high) bounds of the uniform distribution
    :param per_channel: draw a separate factor for every entry along the first axis
    :return: data_sample
    """
    low, high = multiplier_range[0], multiplier_range[1]
    # one factor is always drawn up front, even if every channel gets its own factor below
    global_multiplier = np.random.uniform(low, high)
    if per_channel:
        for channel in range(data_sample.shape[0]):
            data_sample[channel] *= np.random.uniform(low, high)
    else:
        data_sample *= global_multiplier
    return data_sample
105 |
106 |
def _sample_gamma(gamma_range):
    """Draw one gamma value from the (low, high) pair ``gamma_range``."""
    lower, upper = gamma_range[0], gamma_range[1]
    # The coin is always flipped, so the number of random draws does not depend on the range.
    pick_lower_half = np.random.random() < 0.5
    if upper < 1:
        # The whole range lies below 1: sample it directly so gamma cannot exceed `upper`.
        return np.random.uniform(lower, upper)
    if pick_lower_half and lower < 1:
        return np.random.uniform(lower, 1)
    return np.random.uniform(max(lower, 1), upper)


def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon=1e-7, per_channel=False,
                  retain_stats: Union[bool, Callable[[], bool]] = False):
    """
    Gamma-transform data_sample: intensities are rescaled by their min and range, raised to the
    power gamma and scaled back.

    With per_channel=False (or invert_image=True) a new array is returned; with per_channel=True and
    invert_image=False the channels of data_sample are overwritten in place.
    :param data_sample: array to augment; iterated over its first axis if per_channel is True
    :param gamma_range: (low, high) pair gamma is sampled from. If the pair spans 1, values below and
    above 1 are each chosen with probability 0.5
    :param invert_image: negate the data before and after the transform
    :param epsilon: added to the intensity range to avoid a division by zero
    :param per_channel: sample gamma (and compute min/range) separately for every channel
    :param retain_stats: bool, or callable returning a bool, deciding whether mean and standard
    deviation are restored after the transform
    :return: the augmented array
    """
    if invert_image:
        data_sample = - data_sample

    if not per_channel:
        retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
        if retain_stats_here:
            mn = data_sample.mean()
            sd = data_sample.std()
        gamma = _sample_gamma(gamma_range)
        minm = data_sample.min()
        rnge = data_sample.max() - minm
        # NOTE(review): this path scales back by rnge, the per-channel path by rnge + epsilon - kept as is
        data_sample = np.power(((data_sample - minm) / float(rnge + epsilon)), gamma) * rnge + minm
        if retain_stats_here:
            # restore the mean and standard deviation measured before the transform
            data_sample = data_sample - data_sample.mean()
            data_sample = data_sample / (data_sample.std() + 1e-8) * sd
            data_sample = data_sample + mn
    else:
        for c in range(data_sample.shape[0]):
            retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
            if retain_stats_here:
                mn = data_sample[c].mean()
                sd = data_sample[c].std()
            gamma = _sample_gamma(gamma_range)
            minm = data_sample[c].min()
            rnge = data_sample[c].max() - minm
            data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm
            if retain_stats_here:
                # restore the channel's mean and standard deviation measured before the transform
                data_sample[c] = data_sample[c] - data_sample[c].mean()
                data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd
                data_sample[c] = data_sample[c] + mn
    if invert_image:
        data_sample = - data_sample
    return data_sample
148 |
149 |
def augment_illumination(data, white_rgb):
    """
    Re-illuminate every sample of data (in place) with a randomly chosen entry of white_rgb.

    :param data: batch array, iterated over its first axis (samples) and then over channels
    :param white_rgb: candidates; one is drawn per sample and scaled by sqrt(3) to weight the channels
    :return: data
    """
    num_samples = data.shape[0]
    # draw the candidate index for all samples up front
    chosen = np.random.choice(len(white_rgb), num_samples)
    for sample in range(num_samples):
        _, corrected = general_cc_var_num_channels(data[sample], 0, 5, 0, None, 1., 7, False)
        channel_weights = np.array(white_rgb[chosen[sample]]) * np.sqrt(3)
        num_channels = data[sample].shape[0]
        for c in range(num_channels):
            data[sample, c] = corrected[c] * channel_weights[c]
    return data
158 |
159 |
def augment_PCA_shift(data, U, s, sigma=0.2):
    """
    Apply illumination_jitter to every sample of data (in place) and rescale each sample afterwards.

    After the jitter the sample minimum is subtracted and the sample is divided by its maximum.
    :param data: batch array, iterated over its first axis
    :param U: passed through to illumination_jitter
    :param s: passed through to illumination_jitter
    :param sigma: passed through to illumination_jitter
    :return: data
    """
    num_samples = data.shape[0]
    for idx in range(num_samples):
        data[idx] = illumination_jitter(data[idx], U, s, sigma)
        # shift so the minimum is 0, then divide by the new maximum
        data[idx] -= data[idx].min()
        data[idx] /= data[idx].max()
    return data
166 |
--------------------------------------------------------------------------------
/tests/test_axis_mirroring.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | import unittest2
18 | import numpy as np
19 | from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
20 | from skimage import data
21 |
22 | from tests.DataGenerators import BasicDataLoader
23 | from batchgenerators.transforms.spatial_transforms import MirrorTransform
24 |
25 |
class TestMirrorAxis(unittest2.TestCase):
    """
    Tests for MirrorTransform: all combinations of mirrored axes must occur with approximately equal frequency and
    data and seg must always be mirrored in the same way.
    """

    def setUp(self):
        self.seed = 1234

        self.batch_size = 10
        self.num_batches = 1000

        np.random.seed(self.seed)

        ### 2D initializations

        cam = data.camera()
        self.cam = cam[np.newaxis, np.newaxis, :, :]

        self.cam_left = self.cam[:, :, :, ::-1]
        self.cam_updown = self.cam[:, :, ::-1, :]
        self.cam_updown_left = self.cam[:, :, ::-1, ::-1]

        self.x_2D = self.cam
        self.y_2D = self.cam

        ### 3D initializations

        self.cam_3D = np.random.rand(20, 20, 20)[np.newaxis, np.newaxis, :, :, :]

        self.cam_3D_left = self.cam_3D[:, :, :, ::-1, :]
        self.cam_3D_updown = self.cam_3D[:, :, ::-1, :, :]
        self.cam_3D_updown_left = self.cam_3D[:, :, ::-1, ::-1, :]

        self.cam_3D_left_z = self.cam_3D_left[:, :, :, :, ::-1]
        self.cam_3D_updown_z = self.cam_3D_updown[:, :, :, :, ::-1]
        self.cam_3D_updown_left_z = self.cam_3D_updown_left[:, :, :, :, ::-1]
        self.cam_3D_z = self.cam_3D[:, :, :, :, ::-1]

        self.x_3D = self.cam_3D
        self.y_3D = self.cam_3D

    def _count_mirrorings(self, batch_gen, references):
        """
        Draws self.num_batches batches from batch_gen and counts, for each entry of references, how many samples
        are identical to it. Each sample is counted for the first matching reference only.

        :param batch_gen: generator returning dicts with a 'data' entry
        :param references: list of arrays, one per expected mirroring
        :return: array with one count per reference
        """
        counts = np.zeros(shape=(len(references),))
        for b in range(self.num_batches):
            batch = next(batch_gen)
            for ix in range(self.batch_size):
                sample = batch['data'][ix]
                for k, reference in enumerate(references):
                    if (sample == reference).all():
                        # every bin increments its own counter
                        counts[k] += 1
                        break
        return counts

    def _data_and_seg_always_match(self, batch_gen):
        """
        Draws up to self.num_batches batches from batch_gen and checks that 'data' and 'seg' are identical for
        every sample.

        :param batch_gen: generator returning dicts with 'data' and 'seg' entries
        :return: False as soon as a sample differs in at least one element, True otherwise
        """
        for b in range(self.num_batches):
            batch = next(batch_gen)
            for ix in range(self.batch_size):
                # a single differing element is sufficient for a mismatch, hence any() and not all()
                if (batch['data'][ix] != batch['seg'][ix]).any():
                    return False
        return True

    def test_random_distributions_2D(self):
        ### test whether all 4 possible mirrorings occur in approximately equal frequencies in 2D

        batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None)
        batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1)))

        counts = self._count_mirrorings(batch_gen, [self.cam_left, self.cam_updown, self.cam_updown_left, self.cam])

        self.assertTrue(all(2200 < c < 2800 for c in counts), "2D Images were not mirrored along "
                                                              "all axes with equal probability. "
                                                              "This may also indicate that "
                                                              "mirroring is not working")

    def test_segmentations_2D(self):
        ### test whether segmentations are mirrored coherently with images

        batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None)
        batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1)))

        equivalent = self._data_and_seg_always_match(batch_gen)

        self.assertTrue(equivalent, "2D images and seg were not mirrored in the same way (they should though because "
                                    "seg needs to match the corresponding data")

    def test_random_distributions_3D(self):
        ### test whether all 8 possible mirrorings occur in approximately equal frequencies in 3D case

        batch_gen = BasicDataLoader((self.x_3D, self.y_3D), self.batch_size, number_of_threads_in_multithreaded=None)
        batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1, 2)))

        counts = self._count_mirrorings(batch_gen, [self.cam_3D_left, self.cam_3D_updown, self.cam_3D_updown_left,
                                                    self.cam_3D, self.cam_3D_left_z, self.cam_3D_updown_z,
                                                    self.cam_3D_updown_left_z, self.cam_3D_z])

        self.assertTrue(all(1000 < c < 1400 for c in counts), "3D Images were not mirrored along "
                                                              "all axes with equal probability. "
                                                              "This may also indicate that "
                                                              "mirroring is not working")

    def test_segmentations_3D(self):
        ### test whether segmentations are mirrored coherently with images

        batch_gen = BasicDataLoader((self.x_3D, self.y_3D), self.batch_size, number_of_threads_in_multithreaded=None)
        batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1, 2)))

        equivalent = self._data_and_seg_always_match(batch_gen)

        self.assertTrue(equivalent, "3D images and seg were not mirrored in the same way (they should though because "
                                    "seg needs to match the corresponding data")
170 |
# run all tests of this module when the file is executed directly
if __name__ == '__main__':
    unittest.main()
173 |
174 |
--------------------------------------------------------------------------------
/batchgenerators/augmentations/crop_and_pad_augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from builtins import range
17 | import numpy as np
18 | from batchgenerators.augmentations.utils import pad_nd_image
19 |
20 |
def center_crop(data, crop_size, seg=None):
    """
    Center crops data (and seg, if given) to crop_size. Thin wrapper around crop with crop_type='center'.

    :param data: b, c, x, y(, z)
    :param crop_size: int or list/tuple of int (one element for each spatial dimension)
    :param seg: optional, same spatial dimensions as data
    :return: cropped data, cropped seg (None if seg was None)
    """
    return crop(data, seg, crop_size, margins=0, crop_type='center')
23 |
24 |
def get_lbs_for_random_crop(crop_size, data_shape, margins):
    """
    Draws a random lower bound (start index) of the crop for each spatial axis. The lower bound is drawn uniformly
    from [margin, data_shape - crop_size - margin] (both ends included). If that interval does not contain more than
    one position the lower bound of the center crop is used for that axis (this can be negative, which means that
    the data needs to be padded).

    :param crop_size: crop size for each spatial dimension
    :param data_shape: (b,c,x,y(,z)) must be the whole thing!
    :param margins: minimum distance of the crop to the border, one entry for each spatial dimension
    :return: list of lower bounds, one for each spatial dimension
    """
    lbs = []
    for i in range(len(data_shape) - 2):
        highest_lb = data_shape[i + 2] - crop_size[i] - margins[i]
        if highest_lb > margins[i]:
            # np.random.randint excludes its upper limit. +1 is needed so that highest_lb (the crop touching the
            # far margin) can be drawn as well
            lbs.append(np.random.randint(margins[i], highest_lb + 1))
        else:
            lbs.append((data_shape[i + 2] - crop_size[i]) // 2)
    return lbs
40 |
41 |
def get_lbs_for_center_crop(crop_size, data_shape):
    """
    Computes the lower bound (start index) of a centered crop for each spatial axis. The result is negative for
    axes where crop_size is larger than the data.

    :param crop_size: crop size for each spatial dimension
    :param data_shape: (b,c,x,y(,z)) must be the whole thing!
    :return: list of lower bounds, one for each spatial dimension
    """
    num_spatial_dims = len(data_shape) - 2
    return [(data_shape[axis + 2] - crop_size[axis]) // 2 for axis in range(num_spatial_dims)]
52 |
53 |
def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center",
         pad_mode='constant', pad_kwargs={'constant_values': 0},
         pad_mode_seg='constant', pad_kwargs_seg={'constant_values': 0}):
    """
    crops data and seg (seg may be None) to crop_size. Whether this will be achieved via center or random crop is
    determined by crop_type. Margin will be respected only for random_crop and will prevent the crops from being closer
    than margin to the respective image border. crop_size can be larger than data_shape - margin -> data/seg will be
    padded (by default with zeros) in that case. margins can be negative -> results in padding of data/seg followed by
    cropping with margin=0 for the appropriate axes

    :param data: b, c, x, y(, z). Numpy array or list/tuple of arrays (c, x, y(, z))
    :param seg: same spatial dimensions as data, or None
    :param crop_size: int or list/tuple/array of int (one element for each spatial dimension)
    :param margins: distance from each border, can be int or list/tuple of ints (one element for each dimension).
    Can be negative (data/seg will be padded if needed)
    :param crop_type: random or center
    :param pad_mode: passed to np.pad as mode for data
    :param pad_kwargs: passed to np.pad as keyword arguments for data
    :param pad_mode_seg: passed to np.pad as mode for seg
    :param pad_kwargs_seg: passed to np.pad as keyword arguments for seg
    :return: cropped data, cropped seg (None if seg was None). Both are newly allocated arrays
    :raises TypeError: if data or seg is not a numpy array, list or tuple
    :raises NotImplementedError: if crop_type is neither 'center' nor 'random'
    """
    # note: the dict defaults of pad_kwargs/pad_kwargs_seg are only unpacked into np.pad and never modified
    if not isinstance(data, (list, tuple, np.ndarray)):
        raise TypeError("data has to be either a numpy array or a list")
    # seg must be checked before it is indexed below
    if seg is not None and not isinstance(seg, (list, tuple, np.ndarray)):
        raise TypeError("seg has to be either a numpy array or a list")

    data_shape = tuple([len(data)] + list(data[0].shape))
    data_dtype = data[0].dtype
    dim = len(data_shape) - 2

    if seg is not None:
        seg_shape = tuple([len(seg)] + list(seg[0].shape))
        seg_dtype = seg[0].dtype

        assert all([i == j for i, j in zip(seg_shape[2:], data_shape[2:])]), "data and seg must have the same spatial " \
                                                                             "dimensions. Data: %s, seg: %s" % \
                                                                             (str(data_shape), str(seg_shape))

    if not isinstance(crop_size, (tuple, list, np.ndarray)):
        crop_size = [crop_size] * dim
    else:
        assert len(crop_size) == len(
            data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \
                             "data (2d/3d)"

    if not isinstance(margins, (np.ndarray, tuple, list)):
        margins = [margins] * dim

    data_return = np.zeros([data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype)
    if seg is not None:
        seg_return = np.zeros([seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype)
    else:
        seg_return = None

    for b in range(data_shape[0]):
        # samples may have different shapes if data is a list, so the shape is determined for each sample
        data_shape_here = [data_shape[0]] + list(data[b].shape)
        if seg is not None:
            seg_shape_here = [seg_shape[0]] + list(seg[b].shape)

        if crop_type == "center":
            lbs = get_lbs_for_center_crop(crop_size, data_shape_here)
        elif crop_type == "random":
            lbs = get_lbs_for_random_crop(crop_size, data_shape_here, margins)
        else:
            raise NotImplementedError("crop_type must be either center or random")

        # amount of padding [before, after] for each axis. The first entry is the channel axis (never padded)
        need_to_pad = [[0, 0]] + [[abs(min(0, lbs[d])),
                                   abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d])))]
                                  for d in range(dim)]

        # we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed
        ubs = [min(lbs[d] + crop_size[d], data_shape_here[d+2]) for d in range(dim)]
        lbs = [max(0, lbs[d]) for d in range(dim)]

        slicer_data = [slice(0, data_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)]
        data_cropped = data[b][tuple(slicer_data)]

        if seg_return is not None:
            slicer_seg = [slice(0, seg_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)]
            seg_cropped = seg[b][tuple(slicer_seg)]

        if any([i > 0 for j in need_to_pad for i in j]):
            data_return[b] = np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs)
            if seg_return is not None:
                seg_return[b] = np.pad(seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg)
        else:
            data_return[b] = data_cropped
            if seg_return is not None:
                seg_return[b] = seg_cropped

    return data_return, seg_return
143 |
144 |
def random_crop(data, seg=None, crop_size=128, margins=[0, 0, 0]):
    """
    Randomly crops data (and seg, if given) to crop_size. Thin wrapper around crop with crop_type='random'.

    :param data: b, c, x, y(, z)
    :param seg: optional, same spatial dimensions as data
    :param crop_size: int or list/tuple of int (one element for each spatial dimension)
    :param margins: distance from each border, int or list/tuple of ints (one element for each dimension)
    :return: cropped data, cropped seg (None if seg was None)
    """
    return crop(data, seg, crop_size, margins, crop_type='random')
147 |
148 |
def pad_nd_image_and_seg(data, seg, new_shape=None, must_be_divisible_by=None, pad_mode_data='constant',
                         np_pad_kwargs_data=None, pad_mode_seg='constant', np_pad_kwargs_seg=None):
    """
    Pads data and seg to new_shape by calling pad_nd_image on each of them with identical shape arguments.
    new_shape is thereby understood as min_shape (if data/seg is already larger than new_shape the shape stays the
    same for the dimensions this applies)
    :param data:
    :param seg: may be None, in which case only data is padded
    :param new_shape: if none then only must_be_divisible_by is applied
    :param must_be_divisible_by: UNet like architectures sometimes require the input to be divisible by some number.
    This will modify new_shape if new_shape is not divisible by this (by increasing it accordingly).
    must_be_divisible_by should be a list of int (one for each spatial dimension) and this list must have the same
    length as new_shape
    :param pad_mode_data: see np.pad
    :param np_pad_kwargs_data: see np.pad
    :param pad_mode_seg: see np.pad
    :param np_pad_kwargs_seg: see np.pad
    :return: padded data, padded seg (None if seg was None)
    """
    padded_data = pad_nd_image(data, new_shape, mode=pad_mode_data, kwargs=np_pad_kwargs_data,
                               return_slicer=False, shape_must_be_divisible_by=must_be_divisible_by)
    if seg is None:
        return padded_data, None

    padded_seg = pad_nd_image(seg, new_shape, mode=pad_mode_seg, kwargs=np_pad_kwargs_seg,
                              return_slicer=False, shape_must_be_divisible_by=must_be_divisible_by)
    return padded_data, padded_seg
175 |
--------------------------------------------------------------------------------
/tests/test_multithreaded_augmenter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import unittest
16 | from time import sleep
17 |
18 | import numpy as np
19 | from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
20 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
21 | from batchgenerators.examples.multithreaded_dataloading import DummyDL, DummyDLWithShuffle
22 | from batchgenerators.transforms.abstract_transforms import Compose
23 | from batchgenerators.transforms.spatial_transforms import MirrorTransform, TransposeAxesTransform
24 | from batchgenerators.transforms.utility_transforms import NumpyToTensor
25 | from skimage.data import camera, checkerboard, astronaut, binary_blobs, coins
26 | from skimage.transform import resize
27 | from copy import deepcopy
28 |
29 |
class DummyDL2DImage(SlimDataLoaderBase):
    """
    Data loader that serves random batches drawn from five skimage example images (camera, checkerboard, astronaut
    averaged over its last axis, binary_blobs, coins). All images are resized to 224x224 and converted to float32.
    """
    def __init__(self, batch_size, num_threads=8):
        target_shape = (224, 224)
        # the loaders are called one after the other, in exactly this order
        loaders = (camera, checkerboard, lambda: astronaut().mean(-1), binary_blobs, coins)

        images = []
        for load in loaders:
            img = load()
            img = resize(img.astype(np.float64), target_shape, 1, anti_aliasing=False, clip=True,
                         mode='reflect').astype(np.float32)
            # add a leading axis of size 1 in front of each image
            images.append(img[None])

        super(DummyDL2DImage, self).__init__(np.stack(images), batch_size, num_threads)

    def generate_train_batch(self):
        """
        Draws batch_size images at random (with replacement) from self._data.

        :return: dict with key 'data' holding the selected images stacked along the first axis
        """
        picks = np.random.choice(len(self._data), self.batch_size)
        return {'data': np.vstack([self._data[i:i + 1] for i in picks])}
64 |
65 |
class TestMultiThreadedAugmenter(unittest.TestCase):
    """
    This test is inspired by the multithreaded example I did a while back
    """
    def setUp(self):
        np.random.seed(1234)
        self.num_threads = 4
        self.dl = DummyDL(self.num_threads)
        self.dl_with_shuffle = DummyDLWithShuffle(self.num_threads)
        self.dl_images = DummyDL2DImage(4, self.num_threads)

    def _assert_yields_0_to_99_in_order(self, generator):
        """
        Iterates over generator until it is exhausted and checks that it returned exactly the numbers from 0 to 99
        in ascending order. unittest assert methods are used (instead of assert statements) so that the checks are
        not removed when python runs with -O
        :param generator: iterable to check
        :return:
        """
        res = []
        for i in generator:
            res.append(i)

        self.assertEqual(len(res), 100)
        res_copy = deepcopy(res)
        res.sort()
        # sorting must not change anything (-> results were already in ascending order)
        self.assertTrue(all((i == j for i, j in zip(res, res_copy))))
        self.assertTrue(all((i == j for i, j in zip(res, np.arange(0, 100)))))

    def test_no_crash(self):
        """
        This one should just not crash, that's all
        :return:
        """
        dl = self.dl_images
        mt_dl = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False)

        for _ in range(20):
            _ = mt_dl.next()

    def test_DummyDL(self):
        """
        DummyDL must return numbers from 0 to 99 in ascending order
        :return:
        """
        dl = DummyDL(1)
        self._assert_yields_0_to_99_in_order(dl)

    def test_order(self):
        """
        Coordinating workers in a multiprocessing environment is difficult. We want DummyDL in a multithreaded
        environment to still give us the numbers from 0 to 99 in ascending order
        :return:
        """
        dl = self.dl
        mt = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False)

        self._assert_yields_0_to_99_in_order(mt)

    def test_restart_and_order(self):
        """
        Coordinating workers in a multiprocessing environment is difficult. We want DummyDL in a multithreaded
        environment to still give us the numbers from 0 to 99 in ascending order.

        We want the MultiThreadedAugmenter to restart and return the same result in each run
        :return:
        """
        dl = self.dl
        mt = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False)

        # iterating over the same augmenter repeatedly must give the same result every time
        for _ in range(3):
            self._assert_yields_0_to_99_in_order(mt)

    def test_image_pipeline_and_pin_memory(self):
        '''
        This just should not crash
        :return:
        '''
        try:
            import torch
        except ImportError:
            # report the test as skipped (a plain return would report it as passed although nothing was tested)
            self.skipTest("torch is not installed")

        tr_transforms = []
        tr_transforms.append(MirrorTransform())
        tr_transforms.append(TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5))
        tr_transforms.append(NumpyToTensor(keys='data', cast_to='float'))

        composed = Compose(tr_transforms)

        dl = self.dl_images
        mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, True)

        for _ in range(50):
            res = mt.next()

        self.assertIsInstance(res['data'], torch.Tensor)
        self.assertTrue(res['data'].is_pinned())

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)

    def test_image_pipeline(self):
        '''
        This just should not crash
        :return:
        '''

        tr_transforms = []
        tr_transforms.append(MirrorTransform())
        tr_transforms.append(TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5))

        composed = Compose(tr_transforms)

        dl = self.dl_images
        mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, False)

        for _ in range(50):
            res = mt.next()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
218 |
# run all tests of this module when the file is executed directly
if __name__ == '__main__':
    # freeze_support is only relevant for frozen Windows executables that use multiprocessing (see the python
    # multiprocessing documentation); it is called before the tests start any worker processes
    from multiprocessing import freeze_support
    freeze_support()
    unittest.main()
223 |
--------------------------------------------------------------------------------
/batchgenerators/examples/brats2017/brats2017_dataloader_2D.py:
--------------------------------------------------------------------------------
1 | from time import time
2 |
3 | import numpy as np
4 | from batchgenerators.augmentations.crop_and_pad_augmentations import crop
5 | from batchgenerators.augmentations.utils import pad_nd_image
6 | from batchgenerators.dataloading.data_loader import DataLoader
7 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
8 | from batchgenerators.examples.brats2017.brats2017_dataloader_3D import get_list_of_patients, BraTS2017DataLoader3D, \
9 | get_train_transform
10 | from batchgenerators.examples.brats2017.config import brats_preprocessed_folder, num_threads_for_brats_example
11 | from batchgenerators.utilities.data_splitting import get_split_deterministic
12 |
13 |
class BraTS2017DataLoader2D(DataLoader):
    """
    DataLoader that returns batches of randomly selected 2D slices (with matching segmentation) of the
    preprocessed BraTS2017 patients.
    """
    def __init__(self, data, batch_size, patch_size, num_threads_in_multithreaded, seed_for_shuffle=1234, return_incomplete=False,
                 shuffle=True):
        """
        data must be a list of patients as returned by get_list_of_patients (and split by get_split_deterministic)

        patch_size is the spatial size the returned batch will have

        """
        # NOTE(review): the meaning of the trailing positional True is defined by DataLoader.__init__ (not visible
        # here) - please check there
        super().__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle, return_incomplete, shuffle,
                         True)
        self.patch_size = patch_size
        self.num_modalities = 4
        self.indices = list(range(len(data)))

    @staticmethod
    def load_patient(patient):
        """
        Loading is identical to the 3D case, so this just forwards to BraTS2017DataLoader3D.load_patient
        """
        return BraTS2017DataLoader3D.load_patient(patient)

    def generate_train_batch(self):
        """
        Builds one batch: for each selected patient a random slice is taken, padded to at least patch_size and then
        randomly cropped to patch_size.

        :return: dict with 'data' (batch_size, num_modalities, *patch_size), 'seg' (batch_size, 1, *patch_size),
        'metadata' and 'names' (one entry per sample each)
        """
        # DataLoader has its own methods for selecting what patients to use next, see its Documentation
        selected_patients = [self._data[i] for i in self.get_indices()]

        # batch arrays are allocated up front and filled sample by sample
        batch_data = np.zeros((self.batch_size, self.num_modalities, *self.patch_size), dtype=np.float32)
        batch_seg = np.zeros((self.batch_size, 1, *self.patch_size), dtype=np.float32)
        metadata, patient_names = [], []

        for slot, patient in enumerate(selected_patients):
            volume, patient_metadata = self.load_patient(patient)

            # volume is a memmap. If we extract just one slice then just this one slice will be read from the
            # disk, so no worries!
            slice_idx = np.random.choice(volume.shape[1])

            # padding only happens if the slice is smaller than self.patch_size
            image_2d = pad_nd_image(volume[:, slice_idx], self.patch_size)

            # crop expects a leading batch axis which image_2d does not have, so a dummy axis is added via [None].
            # The last channel is the segmentation, everything before it is image data
            cropped_data, cropped_seg = crop(image_2d[:-1][None], image_2d[-1:][None], self.patch_size,
                                             crop_type="random")

            batch_data[slot] = cropped_data[0]
            batch_seg[slot] = cropped_seg[0]

            metadata.append(patient_metadata)
            patient_names.append(patient)

        return {'data': batch_data, 'seg': batch_seg, 'metadata': metadata, 'names': patient_names}
69 |
70 |
# Example script: builds 2D BraTS data loaders, wraps them in MultiThreadedAugmenter and times a few dummy epochs
if __name__ == "__main__":
    patients = get_list_of_patients(brats_preprocessed_folder)

    train, val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)

    patch_size = (160, 160)
    batch_size = 48

    # I recommend you don't use 'iteration over all training data' as epoch because in patch based training this is
    # really not super well defined. If you leave all arguments as default then each batch will contain randomly
    # selected patients. Since we don't care about epochs here we can set num_threads_in_multithreaded to anything.
    dataloader = BraTS2017DataLoader2D(train, batch_size, patch_size, 1)

    batch = next(dataloader)
    try:
        from batchviewer import view_batch
        # batch viewer can show up to 4d tensors. We can show only one sample, but that should be sufficient here
        view_batch(np.concatenate((batch['data'][0], batch['seg'][0]), 0)[:, None])
    except ImportError:
        # remember that batchviewer is missing so that the visualization at the end of the script can be skipped
        view_batch = None
        print("you can visualize batches with batchviewer. It's a nice and handy tool. You can get it here: "
              "https://github.com/FabianIsensee/BatchViewer")

    # now we have some DataLoader. Let's go and get some augmentations

    # first let's collect all shapes, you will see why later
    shapes = [BraTS2017DataLoader2D.load_patient(i)[0].shape[2:] for i in patients]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)

    # we create a new instance of DataLoader. This one will return batches of shape max_shape. Cropping/padding is
    # now done by SpatialTransform. If we do it this way we avoid border artifacts (the entire brain of all cases will
    # be in the batch and SpatialTransform will use zeros which is exactly what we have outside the brain)
    # this is viable here but not viable if you work with different data. If you work for example with CT scans that
    # can be up to 500x500x500 voxels large then you should do this differently. There, instead of using max_shape you
    # should estimate what shape you need to extract so that subsequent SpatialTransform does not introduce border
    # artifacts
    dataloader_train = BraTS2017DataLoader2D(train, batch_size, max_shape, 1)

    # during training I like to run a validation from time to time to see where I am standing. This is not a correct
    # validation because just like training this is patch-based but it's good enough. We don't do augmentation for the
    # validation, so patch_size is used as shape target here
    dataloader_validation = BraTS2017DataLoader2D(val, batch_size, patch_size, 1)

    tr_transforms = get_train_transform(patch_size)

    # finally we can create multithreaded transforms that we can actually use for training
    # we don't pin memory here because this is pytorch specific.
    tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_processes=num_threads_for_brats_example,
                                    num_cached_per_queue=3,
                                    seeds=None, pin_memory=False)
    # we need fewer processes for validation because we don't apply transformations
    val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                     num_processes=max(1, num_threads_for_brats_example // 2), num_cached_per_queue=1,
                                     seeds=None,
                                     pin_memory=False)

    # let's start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training
    # batches while other things run in the main thread
    tr_gen.restart()
    val_gen.restart()

    # now if this was a network training you would run epochs like this (remember tr_gen and val_gen generate
    # infinite examples! Don't do "for batch in tr_gen:"!!!):
    num_batches_per_epoch = 10
    num_validation_batches_per_epoch = 3
    num_epochs = 5
    # let's run this to get a time on how long it takes
    time_per_epoch = []
    start = time()
    for epoch in range(num_epochs):
        start_epoch = time()
        for b in range(num_batches_per_epoch):
            batch = next(tr_gen)
            # do network training here with this batch

        for b in range(num_validation_batches_per_epoch):
            batch = next(val_gen)
            # run validation here
        end_epoch = time()
        time_per_epoch.append(end_epoch - start_epoch)
    end = time()
    total_time = end - start
    print("Running %d epochs took a total of %.2f seconds with time per epoch being %s" %
          (num_epochs, total_time, str(time_per_epoch)))

    # if you notice that you have CPU usage issues, reduce the probability with which the spatial transformations are
    # applied in get_train_transform (down to 0.1 for example). SpatialTransform is the most expensive transform

    # visualize some augmented examples. This only runs if batchviewer could be imported above
    if view_batch is not None:
        for _ in range(4):
            batch = next(tr_gen)
            view_batch(np.concatenate((batch['data'][0], batch['seg'][0]), 0)[:, None])
    else:
        print("Cannot visualize batches, install batchviewer first")
167 |
--------------------------------------------------------------------------------
/batchgenerators/examples/brats2017/brats2017_preprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from batchgenerators.examples.brats2017.config import brats_preprocessed_folder, \
3 | brats_folder_with_downloaded_train_data, num_threads_for_brats_example
4 | from batchgenerators.utilities.file_and_folder_operations import *
5 |
6 | try:
7 | import SimpleITK as sitk
8 | except ImportError:
9 | print("You need to have SimpleITK installed to run this example!")
10 | raise ImportError("SimpleITK not found")
11 |
12 | from multiprocessing import Pool
13 |
14 |
def get_list_of_files(base_dir):
    """
    Collects the image files of all patients found in the 'HGG' and 'LGG' subfolders of base_dir.

    Returns a list of lists containing the filenames. The outer list contains all training examples. Each entry in
    the outer list is again a list pointing to the files of that training example in the following order:
    T1, T1c, T2, FLAIR, segmentation

    An AssertionError is raised if one of the five files of a patient does not exist.
    :param base_dir: folder that contains the HGG and LGG folders
    :return: list of lists of filenames
    """
    # order matters: T1, T1c, T2, FLAIR, segmentation
    suffixes = ("_t1.nii.gz", "_t1ce.nii.gz", "_t2.nii.gz", "_flair.nii.gz", "_seg.nii.gz")

    list_of_lists = []
    for glioma_type in ['HGG', 'LGG']:
        current_directory = join(base_dir, glioma_type)
        for p in subfolders(current_directory, join=False):
            patient_directory = join(current_directory, p)
            # all files of a patient are named <patient><suffix> and are located in the patient's folder
            this_case = [join(patient_directory, p + suffix) for suffix in suffixes]
            assert all((isfile(i) for i in this_case)), "some file is missing for patient %s; make sure the following " \
                                                        "files are there: %s" % (p, str(this_case))
            list_of_lists.append(this_case)
    print("Found %d patients" % len(list_of_lists))
    return list_of_lists
40 |
41 |
def load_and_preprocess(case, patient_name, output_folder):
    """
    loads, preprocesses and saves a case
    This is what happens here:
    1) load all images and stack them to a 4d array
    2) crop to nonzero region, this removes unnecessary zero-valued regions and reduces computation time
    3) normalize the nonzero region with its mean and standard deviation
    4) save 4d tensor as numpy array. Also save metadata required to create niftis again (required for export
    of predictions)

    :param case: list of image files belonging to one patient. The last entry is treated as the segmentation, all
    others as image modalities (see get_list_of_files)
    :param patient_name: identifier of the case, used as file name for the outputs
    :param output_folder: folder in which patient_name.npy and patient_name.pkl are written
    :return: None
    :raises ValueError: if all images of the case are zero everywhere (there is nothing to crop to)
    """
    # load SimpleITK Images
    imgs_sitk = [sitk.ReadImage(i) for i in case]

    # get pixel arrays from SimpleITK images
    imgs_npy = [sitk.GetArrayFromImage(i) for i in imgs_sitk]

    # get some metadata
    spacing = imgs_sitk[0].GetSpacing()
    # the spacing returned by SimpleITK is in inverse order relative to the numpy array we receive. If we wanted to
    # resample the data and if the spacing was not isotropic (in BraTS all cases have already been resampled to 1x1x1mm
    # by the organizers) then we need to pay attention here. Therefore we bring the spacing into the correct order so
    # that spacing[0] actually corresponds to the spacing of the first axis of the numpy array
    spacing = np.array(spacing)[::-1]

    direction = imgs_sitk[0].GetDirection()
    origin = imgs_sitk[0].GetOrigin()

    original_shape = imgs_npy[0].shape

    # now stack the images into one 4d array, cast to float because we will get rounding problems if we don't
    imgs_npy = np.concatenate([i[None] for i in imgs_npy]).astype(np.float32)

    # now find the nonzero region and crop to that. The bounding box is computed on the union of the nonzero voxels of
    # all channels. This yields the same box as taking the min/max over per-channel boxes but does not break if a
    # single channel (for example an empty segmentation) is zero everywhere
    nonzero_union = np.any(imgs_npy != 0, axis=0)
    if not nonzero_union.any():
        raise ValueError("all images of case %s are zero everywhere, cannot crop to the nonzero region" % patient_name)
    nonzero_coords = np.array(np.where(nonzero_union))
    nonzero = np.array([np.min(nonzero_coords, 1), np.max(nonzero_coords, 1)]).T
    # nonzero now has shape 3, 2. It contains the (min, max) coordinate of nonzero voxels for each axis

    # now crop to nonzero (max coordinates are inclusive, hence the + 1)
    imgs_npy = imgs_npy[:,
               nonzero[0, 0]: nonzero[0, 1] + 1,
               nonzero[1, 0]: nonzero[1, 1] + 1,
               nonzero[2, 0]: nonzero[2, 1] + 1,
               ]

    # now we create a brain mask that we use for normalization. A voxel belongs to the mask if any of the image
    # modalities (everything except the last channel, which is the segmentation) is nonzero there
    brain_mask = np.any(imgs_npy[:-1] != 0, axis=0)

    # now normalize each modality with its mean and standard deviation (computed within the brain mask)
    for i in range(len(imgs_npy) - 1):
        mean = imgs_npy[i][brain_mask].mean()
        std = imgs_npy[i][brain_mask].std()
        # the small constant prevents a division by zero for constant images
        imgs_npy[i] = (imgs_npy[i] - mean) / (std + 1e-8)
        # everything outside of the brain mask is set back to 0
        imgs_npy[i][~brain_mask] = 0

    # the segmentation of brats has the values 0, 1, 2 and 4. This is pretty inconvenient to say the least.
    # We move everything that is 4 to 3
    imgs_npy[-1][imgs_npy[-1] == 4] = 3

    # now save as npy
    np.save(join(output_folder, patient_name + ".npy"), imgs_npy)

    metadata = {
        'spacing': spacing,
        'direction': direction,
        'origin': origin,
        'original_shape': original_shape,
        'nonzero_region': nonzero
    }

    save_pickle(metadata, join(output_folder, patient_name + ".pkl"))
120 |
121 |
def save_segmentation_as_nifti(segmentation, metadata, output_file):
    """
    Places a cropped segmentation back into a volume of the original image shape and writes it with SimpleITK.

    :param segmentation: segmentation of the cropped region, its shape must match the extent of
    metadata['nonzero_region']
    :param metadata: dict with the keys 'original_shape', 'nonzero_region', 'direction', 'origin' and 'spacing'
    (this is what load_and_preprocess stores in the pkl file)
    :param output_file: file name that is passed to sitk.WriteImage
    :return: None
    """
    # everything outside of the cropped region is background (0)
    full_size_seg = np.zeros(metadata['original_shape'], dtype=np.uint8)

    bbox = metadata['nonzero_region']
    # bbox[d] is the (min, max) coordinate along axis d, the max coordinate is inclusive
    crop_region = tuple(slice(bbox[d, 0], bbox[d, 1] + 1) for d in range(3))
    full_size_seg[crop_region] = segmentation

    image = sitk.GetImageFromArray(full_size_seg)
    image.SetDirection(metadata['direction'])
    image.SetOrigin(metadata['origin'])
    # spacing was stored in numpy axis order, SimpleITK expects the reverse order
    image.SetSpacing(tuple(metadata['spacing'][[2, 1, 0]]))
    sitk.WriteImage(image, output_file)
135 |
136 |
if __name__ == "__main__":
    # This is the same preprocessing I used for our contributions to the BraTS 2017 and 2018 challenges.
    # Preprocessing is described in the documentation of load_and_preprocess

    # The training data is identical between BraTS 2017 and 2018. You can request access here:
    # https://ipp.cbica.upenn.edu/#BraTS18_registration

    # brats_folder_with_downloaded_train_data points to where the extracted downloaded training data is

    # preprocessed data is saved as npy. This may seem odd if you are familiar with medical images, but trust me it's
    # the best way to do this for deep learning. It does not make much of a difference for BraTS, but if you are
    # dealing with larger images this is crucial for your pipelines to not get stuck in CPU bottleneck. What we can do
    # with numpy arrays is we can load them via np.load(file, mmap_mode="r") and then read just parts of it on the fly
    # during training. This is super important if your patch size is smaller than the size of the entire patient (for
    # example if you work with large CT data or if you need 2D slices).
    # For this to work properly the output_folder (or wherever the data is stored during training) must be on an SSD!
    # HDDs are usually too slow and you also wouldn't want to do this over a network share (there are exceptions but
    # take this as a rule of thumb)

    # Why is this not an IPython Notebook you may ask? Because I HATE IPython Notebooks. Simple :-)

    import os

    list_of_lists = get_list_of_files(brats_folder_with_downloaded_train_data)

    maybe_mkdir_p(brats_preprocessed_folder)

    # the patient name is the name of the folder that contains the files of a case. Use os.path instead of splitting
    # at "/" so that this also works with the path separator used on Windows
    patient_names = [os.path.basename(os.path.dirname(i[0])) for i in list_of_lists]

    # the context manager makes sure the worker processes are cleaned up even if preprocessing of a case fails
    with Pool(processes=num_threads_for_brats_example) as p:
        p.starmap(load_and_preprocess,
                  zip(list_of_lists, patient_names, [brats_preprocessed_folder] * len(list_of_lists)))
        p.close()
        p.join()

    # remember that we cropped the data before preprocessing. If we predict the test cases, we want to run the same
    # preprocessing for them. We need to then put the segmentation back into its original position (due to cropping).
    # Here is how you can do that:

    # lets use Brats17_2013_0_1 for this example
    img = np.load(join(brats_preprocessed_folder, "Brats17_2013_0_1.npy"))
    metadata = load_pickle(join(brats_preprocessed_folder, "Brats17_2013_0_1.pkl"))
    # remember that we changed the segmentation labels from 0, 1, 2, 4 to 0, 1, 2, 3. We need to change that back to
    # get the correct format
    img[-1][img[-1] == 3] = 4
    save_segmentation_as_nifti(img[-1], metadata, join(brats_preprocessed_folder, "delete_me.nii.gz"))
180 |
--------------------------------------------------------------------------------
/tests/test_color_augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | import numpy as np
18 | from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive,\
19 | augment_brightness_multiplicative, augment_gamma
20 |
21 |
class TestAugmentContrast(unittest.TestCase):
    """Checks that the output of augment_contrast stays within the bounds implied by the contrast range."""

    def setUp(self):
        np.random.seed(1234)
        self.data_3D = np.random.random((2, 64, 56, 48))
        self.data_2D = np.random.random((2, 64, 56))
        self.factor = (0.75, 1.25)

        self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=False)
        self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=False)

    def _assert_contrast_within_range(self, original, augmented):
        """
        Asserts that each augmented value lies between the values obtained by scaling the distance of the original
        value to the mean with the smallest and with the largest contrast factor.
        """
        mean = np.mean(original)

        below = np.where(original < mean)  # where the data is lower than mean value
        above = np.where(original > mean)  # where the data is greater than mean value

        # below the mean the distance to the mean is negative, so the larger factor yields the lower limit
        lower_limit_below = self.factor[1] * (original[below] - mean) + mean
        lower_limit_above = self.factor[0] * (original[above] - mean) + mean
        upper_limit_below = self.factor[0] * (original[below] - mean) + mean
        upper_limit_above = self.factor[1] * (original[above] - mean) + mean

        below_in_range = np.logical_and(augmented[below] >= lower_limit_below,
                                        augmented[below] <= upper_limit_below)
        self.assertTrue(np.all(below_in_range), "Augmented contrast below mean value not within range")

        above_in_range = np.logical_and(augmented[above] >= lower_limit_above,
                                        augmented[above] <= upper_limit_above)
        self.assertTrue(np.all(above_in_range), "Augmented contrast above mean not within range")

    def test_augment_contrast_3D(self):
        self._assert_contrast_within_range(self.data_3D, self.d_3D)

    def test_augment_contrast_2D(self):
        self._assert_contrast_within_range(self.data_2D, self.d_2D)
74 |
75 |
class TestAugmentBrightness(unittest.TestCase):
    """
    Checks the number of distinct brightness offsets/multipliers produced by the additive and multiplicative
    brightness augmentations: exactly one with per_channel=False and one per entry of the first axis with
    per_channel=True.
    """

    def setUp(self):
        np.random.seed(1234)
        self.data_input_3D = np.random.random((2, 64, 56, 48))
        self.data_input_2D = np.random.random((2, 64, 56))
        self.factor = (0.75, 1.25)
        self.multiplier_range = [2,4]

        # the order of the calls below determines which random numbers each augmentation receives
        additive_kwargs = {'mu': 100, 'sigma': 10}
        self.d_3D_per_channel = augment_brightness_additive(np.copy(self.data_input_3D), per_channel=True,
                                                            **additive_kwargs)
        self.d_3D = augment_brightness_additive(np.copy(self.data_input_3D), per_channel=False, **additive_kwargs)

        self.d_2D_per_channel = augment_brightness_additive(np.copy(self.data_input_2D), per_channel=True,
                                                            **additive_kwargs)
        self.d_2D = augment_brightness_additive(np.copy(self.data_input_2D), per_channel=False, **additive_kwargs)

        self.d_3D_per_channel_mult = augment_brightness_multiplicative(np.copy(self.data_input_3D),
                                                                       multiplier_range=self.multiplier_range,
                                                                       per_channel=True)
        self.d_3D_mult = augment_brightness_multiplicative(np.copy(self.data_input_3D),
                                                           multiplier_range=self.multiplier_range, per_channel=False)

        self.d_2D_per_channel_mult = augment_brightness_multiplicative(np.copy(self.data_input_2D),
                                                                       multiplier_range=self.multiplier_range,
                                                                       per_channel=True)
        self.d_2D_mult = augment_brightness_multiplicative(np.copy(self.data_input_2D),
                                                           multiplier_range=self.multiplier_range, per_channel=False)

    @staticmethod
    def _num_unique(values, decimals):
        """Returns the number of distinct entries of values after rounding to the given number of decimals."""
        return len(np.unique(values.round(decimals=decimals)))

    def test_augment_brightness_additive_3D(self):
        num_offsets = self._num_unique(self.d_3D - self.data_input_3D, 8)
        self.assertTrue(num_offsets == 1,
                        "Added brightness factor is not equal for all channels")

        num_offsets = self._num_unique(self.d_3D_per_channel - self.data_input_3D, 8)
        self.assertTrue(num_offsets == self.data_input_3D.shape[0],
                        "Added brightness factor is not different for each channels")

    def test_augment_brightness_additive_2D(self):
        num_offsets = self._num_unique(self.d_2D - self.data_input_2D, 8)
        self.assertTrue(num_offsets == 1,
                        "Added brightness factor is not equal for all channels")

        num_offsets = self._num_unique(self.d_2D_per_channel - self.data_input_2D, 8)
        self.assertTrue(num_offsets == self.data_input_2D.shape[0],
                        "Added brightness factor is not different for each channels")

    def test_augment_brightness_multiplicative_3D(self):
        num_multipliers = self._num_unique(self.d_3D_mult / self.data_input_3D, 6)
        self.assertTrue(num_multipliers == 1,
                        "Multiplied brightness factor is not equal for all channels")

        num_multipliers = self._num_unique(self.d_3D_per_channel_mult / self.data_input_3D, 6)
        self.assertTrue(num_multipliers == self.data_input_3D.shape[0],
                        "Multiplied brightness factor is not different for each channels")

    def test_augment_brightness_multiplicative_2D(self):
        num_multipliers = self._num_unique(self.d_2D_mult / self.data_input_2D, 6)
        self.assertTrue(num_multipliers == 1,
                        "Multiplied brightness factor is not equal for all channels")

        num_multipliers = self._num_unique(self.d_2D_per_channel_mult / self.data_input_2D, 6)
        self.assertTrue(num_multipliers == self.data_input_2D.shape[0],
                        "Multiplied brightness factor is not different for each channels")
140 |
141 |
class TestAugmentGamma(unittest.TestCase):
    """Checks that augment_gamma leaves the value range (min and max) of the input unchanged."""

    def setUp(self):
        np.random.seed(1234)
        self.data_input_3D = np.random.random((2, 64, 56, 48))
        self.data_input_2D = np.random.random((2, 64, 56))

        # each output must be computed from the input it is later compared against (the 3D result used to be
        # computed from the 2D input, so 3D data was never actually tested)
        self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=False)
        self.d_2D = augment_gamma(np.copy(self.data_input_2D), gamma_range=(0.2, 1.2), per_channel=False)

    def _assert_range_preserved(self, original, augmented):
        """Asserts that min and max of augmented equal those of original (after rounding to 3 decimals)."""
        self.assertTrue(augmented.min().round(decimals=3) == original.min().round(decimals=3) and
                        augmented.max().round(decimals=3) == original.max().round(decimals=3),
                        "Input range does not equal output range")

    def test_augment_gamma_3D(self):
        self._assert_range_preserved(self.data_input_3D, self.d_3D)

    def test_augment_gamma_2D(self):
        self._assert_range_preserved(self.data_input_2D, self.d_2D)
155 |
156 |
# run all test cases of this module when the file is executed as a script
if __name__ == '__main__':
    unittest.main()
159 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/crop_and_pad_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop, pad_nd_image_and_seg, random_crop
17 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
18 | import numpy as np
19 |
20 |
class CenterCropTransform(AbstractTransform):
    """ Crops data and seg (if available) in the center

    Args:
        crop_size (int or tuple of int): Output patch size

        data_key (str): key of the data entry in the data dict

        label_key (str): key of the segmentation entry in the data dict

    """

    def __init__(self, crop_size, data_key="data", label_key="seg"):
        self.crop_size = crop_size
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        # seg is None if the data dict has no entry for label_key
        cropped_data, cropped_seg = center_crop(data_dict.get(self.data_key), self.crop_size,
                                                data_dict.get(self.label_key))

        data_dict[self.data_key] = cropped_data
        # only write a segmentation back if center_crop returned one
        if cropped_seg is not None:
            data_dict[self.label_key] = cropped_seg

        return data_dict
44 |
45 |
class CenterCropSegTransform(AbstractTransform):
    """ Crops seg in the center (required if you are using unpadded convolutions in a segmentation network).
    Leaves data as it is. If the data dict has no seg, a warning is issued and the data dict is returned unchanged

    Args:
        output_size (int or tuple of int): Output patch size

        data_key (str): key of the data entry in the data dict (stored but not used by this transform)

        label_key (str): key of the segmentation entry in the data dict

    """

    def __init__(self, output_size, data_key="data", label_key="seg"):
        self.output_size = output_size
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            # nothing to crop, tell the user about it and pass the data dict on as it is
            import warnings
            warnings.warn("You shall not pass data_dict without seg: Used CenterCropSegTransform, but there is no seg",
                          Warning)
            return data_dict

        # center_crop returns a tuple, its first entry is the cropped array that was passed as first argument
        data_dict[self.label_key] = center_crop(seg, self.output_size, None)[0]
        return data_dict
69 |
70 |
class RandomCropTransform(AbstractTransform):
    """ Randomly crops data and seg (if available)

    Args:
        crop_size (int or tuple of int): Output patch size

        margins (tuple of int): how much distance should the patch border have to the image border (bilaterally)?

        data_key (str): key of the data entry in the data dict

        label_key (str): key of the segmentation entry in the data dict

    """

    def __init__(self, crop_size=128, margins=(0, 0, 0), data_key="data", label_key="seg"):
        self.crop_size = crop_size
        self.margins = margins
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict):
        # seg is None if the data dict has no entry for label_key
        cropped_data, cropped_seg = random_crop(data_dict.get(self.data_key), data_dict.get(self.label_key),
                                                self.crop_size, self.margins)

        data_dict[self.data_key] = cropped_data
        # only write a segmentation back if random_crop returned one
        if cropped_seg is not None:
            data_dict[self.label_key] = cropped_seg

        return data_dict
98 |
99 |
class PadTransform(AbstractTransform):
    def __init__(self, new_size, pad_mode_data='constant', pad_mode_seg='constant',
                 np_pad_kwargs_data=None, np_pad_kwargs_seg=None,
                 data_key="data", label_key="seg"):
        """
        Pads data and seg to new_size. Only supports numpy arrays for data and seg.

        :param new_size: (x, y(, z)), must be a tuple, list or np.ndarray
        :param pad_mode_data: padding mode for data, forwarded to pad_nd_image_and_seg
        :param pad_mode_seg: padding mode for seg, forwarded to pad_nd_image_and_seg
        :param np_pad_kwargs_data: additional padding kwargs for data, None is replaced by an empty dict
        :param np_pad_kwargs_seg: additional padding kwargs for seg, None is replaced by an empty dict
        :param data_key:
        :param label_key:
        """
        assert isinstance(new_size, (tuple, list, np.ndarray)), "new_size must be tuple, list or np.ndarray"

        self.new_size = new_size
        self.data_key = data_key
        self.label_key = label_key
        self.pad_mode_data = pad_mode_data
        self.pad_mode_seg = pad_mode_seg
        # None is used as default to not share one mutable dict between instances
        self.np_pad_kwargs_data = {} if np_pad_kwargs_data is None else np_pad_kwargs_data
        self.np_pad_kwargs_seg = {} if np_pad_kwargs_seg is None else np_pad_kwargs_seg

    def __call__(self, **data_dict):
        data = data_dict.get(self.data_key)
        seg = data_dict.get(self.label_key)

        # data has two leading non-spatial axes, new_size must have one entry for each remaining axis
        assert len(data.shape) == len(self.new_size) + 2, "new size must be a tuple/list/np.ndarray with shape " \
                                                          "(x, y(, z))"
        padded_data, padded_seg = pad_nd_image_and_seg(data, seg, self.new_size, None,
                                                       np_pad_kwargs_data=self.np_pad_kwargs_data,
                                                       np_pad_kwargs_seg=self.np_pad_kwargs_seg,
                                                       pad_mode_data=self.pad_mode_data,
                                                       pad_mode_seg=self.pad_mode_seg)

        data_dict[self.data_key] = padded_data
        # only write a segmentation back if pad_nd_image_and_seg returned one
        if padded_seg is not None:
            data_dict[self.label_key] = padded_seg

        return data_dict
144 |
145 |
class RandomShiftTransform(AbstractTransform):
    def __init__(self, shift_mu, shift_sigma, p_per_sample=1, p_per_channel=0.5, border_value=0, apply_to_keys=('data',)):
        """
        randomly shifts the data by some amount. Equivalent to pad -> random crop but with (probably) less
        computational requirements

        shift_mu gives the mean value of the shift, 0 is recommended
        shift_sigma gives the standard deviation of the shift

        shift will be drawn from a Gaussian distribution with mean shift_mu and standard deviation shift_sigma and
        is rounded to the nearest integer. A positive shift moves the content towards lower indices

        shift_mu and shift_sigma can either be float values OR tuples of float values. If they are tuples they will
        be interpreted as separate mean and std for each dimension

        A separate shift is drawn for each channel that is selected for augmentation

        :param shift_mu:
        :param shift_sigma:
        :param p_per_sample: probability with which a sample is considered for shifting
        :param p_per_channel: probability with which a channel of a considered sample is shifted
        :param border_value: value that is written to the region that has no source data after the shift
        :param apply_to_keys: keys of the data dict the shift is applied to
        """
        self.apply_to_keys = apply_to_keys
        self.p_per_channel = p_per_channel
        self.p_per_sample = p_per_sample
        self.shift_sigma = shift_sigma
        self.shift_mu = shift_mu
        self.border_value = border_value

    def _draw_shift(self, num_dims):
        """Draws one integer shift for each of the num_dims spatial dimensions."""
        shift = []
        for d in range(num_dims):
            mu = self.shift_mu[d] if isinstance(self.shift_mu, (list, tuple)) else self.shift_mu
            sigma = self.shift_sigma[d] if isinstance(self.shift_sigma, (list, tuple)) else self.shift_sigma
            # draw a scalar (no size argument): converting a size-1 array with int() is deprecated in recent numpy
            shift.append(int(np.round(np.random.normal(mu, sigma))))
        return shift

    def __call__(self, **data_dict):
        for k in self.apply_to_keys:
            workon = data_dict[k]
            # the first two axes are indexed as sample (b) and channel (c), the remaining ones are shifted
            spatial_shape = workon.shape[2:]
            for b in range(workon.shape[0]):
                if np.random.uniform(0, 1) < self.p_per_sample:
                    for c in range(workon.shape[1]):
                        if np.random.uniform(0, 1) < self.p_per_channel:
                            shift_here = self._draw_shift(len(spatial_shape))

                            # region of the original data that remains visible after the shift ...
                            source = tuple(slice(max(s, 0), max(0, min(size, size + s)))
                                           for s, size in zip(shift_here, spatial_shape))
                            # ... and where it ends up. Each bound uses the size of its own axis (the upper bound
                            # for z used to be clipped with the size of the x axis)
                            target = tuple(slice(max(-s, 0), max(0, min(size, size - s)))
                                           for s, size in zip(shift_here, spatial_shape))

                            # everything that is not overwritten below keeps the border value
                            data_copy = np.ones_like(workon[b, c]) * self.border_value
                            data_copy[target] = workon[b, c][source]
                            data_dict[k][b, c] = data_copy
        return data_dict
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # batchgenerators by MIC@DKFZ
2 |
3 | Copyright German Cancer Research Center (DKFZ) and contributors.
4 | Please make sure that your usage of this code is in compliance with its
5 | [`license`](https://github.com/MIC-DKFZ/batchgenerators/blob/master/LICENSE).
6 |
7 | batchgenerators is a python package for data augmentation. It is developed jointly between the Division of
8 | Medical Image Computing at the German Cancer Research Center (DKFZ) and the Applied Computer
9 | Vision Lab of the Helmholtz Imaging Platform.
10 |
It is not (yet) perfect, but we feel it is good enough to be shared with the community. If you encounter a bug, feel free
to contact us or open a GitHub issue.
13 |
14 | If you use it please cite the following work:
15 | ```
16 | Isensee Fabian, Jäger Paul, Wasserthal Jakob, Zimmerer David, Petersen Jens, Kohl Simon,
17 | Schock Justus, Klein Andre, Roß Tobias, Wirkert Sebastian, Neher Peter, Dinkelacker Stefan,
18 | Köhler Gregor, Maier-Hein Klaus (2020). batchgenerators - a python framework for data
19 | augmentation. doi:10.5281/zenodo.3632567
20 | ```
21 |
22 | batchgenerators also contains the following application-specific augmentations:
23 | * **Anatomy-informed Data Augmentation**
24 | Proposed at [MICCAI 2023](https://arxiv.org/abs/2309.03652) for simulation of soft-tissue deformations. Implementation details can be found [here](https://github.com/MIC-DKFZ/anatomy_informed_DA).
25 | * **Misalignment Data Augmentation**
26 | Proposed in [Nature Scientific Reports 2023](https://www.nature.com/articles/s41598-023-46747-z)
27 | for enhancing model's adaptability to diverse misalignments\
28 | between multi-modal (multi-channel) images and thereby ensuring robust performance. Implementation details can be found [here](https://github.com/MIC-DKFZ/misalignment_DA).
29 |
30 | If you use these augmentations please cite them too.
31 |
32 | [](https://travis-ci.com/github/MIC-DKFZ/batchgenerators)
33 |
34 | ## Supported Augmentations
We support a variety of augmentations, all of which are compatible with **2D and 3D input data**! (This is something
that was missing in most other frameworks).
37 |
38 | * **Spatial Augmentations**
39 | * mirroring
40 | * channel translation (to simulate registration errors)
41 | * elastic deformations
42 | * rotations
43 | * scaling
44 | * resampling
45 | * multi-channel misalignments
46 | * **Color Augmentations**
* brightness (additive, multiplicative)
48 | * contrast
49 | * gamma (like gamma correction in photo editing)
50 | * **Noise Augmentations**
51 | * Gaussian Noise
52 | * Rician Noise
53 | * ...will be expanded in future commits
54 | * **Cropping**
55 | * random crop
56 | * center crop
57 | * padding
58 | * **Anatomy-informed Augmentation**
59 |
60 | Note: Stack transforms by using batchgenerators.transforms.abstract_transforms.Compose. Finish it up by plugging the
61 | composed transform into our **multithreader**: batchgenerators.dataloading.multi_threaded_augmenter.MultiThreadedAugmenter
62 |
63 |
64 | ## How to use it
65 |
66 | The working principle is simple: Derive from DataLoaderBase class, reimplement generate_train_batch member function and
67 | use it to stack your augmentations!
For a simple example, see `batchgenerators/examples/example_ipynb.ipynb`
69 |
A heavily commented example for using SlimDataLoaderBase and MultiThreadedAugmenter is available at:
`batchgenerators/examples/multithreaded_with_batches.ipynb`.
It gives an idea of the interplay between the SlimDataLoaderBase and the MultiThreadedAugmenter.
The example uses the MultiThreadedAugmenter for loading and augmentation on multiple processes, while
covering the entire dataset only once per epoch (basically sampling without replacement).
75 |
76 | We also now have an extensive example for BraTS2017/2018 with both 2D and 3D DataLoader and augmentations:
77 | `batchgenerators/examples/brats2017/`
78 |
79 | There are also CIFAR10/100 datasets and DataLoader available at `batchgenerators/datasets/cifar.py`
80 |
81 | ## Data Structure
82 |
The data structure that is used internally (and with which you have to comply when implementing generate_train_batch)
is kept simple as well: It is just a regular python dictionary! We did this to allow maximum flexibility in the kind of
data that is passed along through the pipeline. The dictionary must have a 'data' key:value pair. It optionally can
handle a 'seg' key:value pair to hold a segmentation. If a 'seg' key:value pair is present all spatial transformations
will also be applied to the segmentation! Apart from 'data' and 'seg' you are free to do whatever you want (your image
classification/regression target for example). All key:value pairs other than 'data' and 'seg' will be passed through the
pipeline unmodified.
90 |
91 | 'data' value must have shape (b, c, x, y) for 2D or shape (b, c, x, y, z) for 3D!
92 | 'seg' value must have shape (b, c, x, y) for 2D or shape (b, c, x, y, z) for 3D! Color channel may be used here to
93 | allow for several segmentation maps. If you have only one segmentation, make sure to have shape (b, 1, x, y (, z))
94 |
95 | ## How to install locally
96 |
97 | Install batchgenerators
98 | ```
99 | pip install --upgrade batchgenerators
100 | ```
101 |
102 | Import as follows
103 | ```
104 | from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform
105 | ```
106 |
107 | ## Windows Support is very experimental!
108 | Batchgenerators makes heavy use of python multiprocessing and python multiprocessing on windows is different from linux.
109 | To prevent the workers from freezing in windows, you have to guard your code with `if __name__ == '__main__'` and use multiprocessing's [`freeze_support`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.freeze_support). The executed script may then look like this:
110 |
111 | ```
112 | # some imports and functions here
113 |
114 | def main():
115 | # do some stuff
116 |
117 | if __name__ == '__main__':
118 | from multiprocessing import freeze_support
119 | freeze_support()
120 | main()
121 | ```
122 |
123 | This is not required on Linux.
124 |
125 |
126 | ## Release Notes
127 | (only highlights, not an exhaustive list)
128 | - 0.23.2:
129 | - Misalignment data augmentation added
130 | - 0.23.1:
131 | - Anatomy-informed data augmentation added
132 | - 0.23:
133 | - fixed the import mess. `__init__.py` files are now empty. This is a breaking change for some users!
134 | Please adapt your imports :-)
135 | - local_transforms are now a thing, check them out!
136 | - resize_segmentation now uses 'edge' mode and no longer takes a cval argument. Resizing segmentations with constant
137 | border values (previous default) can cause problems and should not be done.
138 | - 0.20.0:
139 | - fixed an issue with MultiThreadedAugmenter not terminating properly after KeyboardInterrupt; Fixed an error
140 | with the number and order of samples being returned when pin_memory=True; Improved performance by always hiding
141 | process-process communication bottleneck through threading
142 | - 0.19.5:
143 | - fixed OMP_NUM_THREADS issue by using threadpoolctl package; dropped python 2 support (threadpoolctl is not
144 | available for python 2)
145 | - 0.19:
146 | - There is now a complete example for BraTS2017/8 available for both 2D and 3D. Use this if you would like to get
147 | some insights on how I (Fabian) do my experiments
148 | - Windows is now supported! Thanks @justusschock for your support!
149 | - new, simple parametrization of elastic deformation. Use SpatialTransform_2!
150 | - CIFAR10/100 DataLoader are now available for your convenience
151 | - a bug in MultiThreadedAugmenter that could interfere with reproducibility is now fixed
152 |
153 | - 0.18:
154 | - all augmentations (there are some exceptions though) are implemented on a per-sample basis. This should make it
155 | easier to use the augmentations outside of the Transforms of batchgenerators
156 | - applicable Transforms now have a keyword p_per_sample with which the user can specify a probability with which this
157 | transform is applied to a sample. Before, this was handled by RndTransform and applied to the whole batch (so
158 | either all samples were augmented or none). Now this decision is made on a per-sample basis and increases
159 | variability by a lot.
160 | - following the previous point, RndTransform is now deprecated
161 | - AlternativeMultiThreadedAugmenter is now deprecated as well (no need to have this anymore)
162 | - pytorch users can now transform numpy arrays to pytorch tensors within batchgenerators (NumpyToTensor). For some
163 | reason, inter-process communication is faster with tensors (~factor 4), so this is recommended!
164 | - if numpy arrays were converted to pytorch tensors, MultiThreadedAugmenter now allows pinning the memory as well
165 | (pin_memory=True). This will happen in a background thread (inspired by pytorch DataLoader). Pinned memory can be
166 | copied to the GPU much faster. My (Fabian) classification experiment with Resnet50 got a speed boost of 12% from just
167 | that.
168 |
169 |
170 | -------------------------
171 |
172 |
173 |
174 |
175 |
176 | batchgenerators is developed by the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) of the
177 | German Cancer Research Center (DKFZ) and the Applied Computer Vision Lab (ACVL) of the
178 | [Helmholtz Imaging Platform](https://helmholtz-imaging.de).
179 |
--------------------------------------------------------------------------------
/batchgenerators/transforms/color_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Union, Tuple, Callable
17 |
18 | import numpy as np
19 |
20 | from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive, \
21 | augment_brightness_multiplicative, augment_gamma, augment_illumination, augment_PCA_shift
22 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
23 |
24 |
class ContrastAugmentationTransform(AbstractTransform):
    def __init__(self,
                 contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25),
                 preserve_range: bool = True,
                 per_channel: bool = True,
                 data_key: str = "data",
                 p_per_sample: float = 1,
                 p_per_channel: float = 1):
        """
        Randomly changes the contrast of each sample in a batch.

        :param contrast_range: either a (low, high) tuple from which the contrast factor is sampled, or a callable
        with signature contrast_range() -> float. For a tuple that straddles 1, half of the sampled factors will be
        < 1 and the other half > 1 (within the given interval).
        :param preserve_range: if True, intensities are clipped to the min/max of the data before augmentation
        :param per_channel: whether a separate contrast factor is used for each channel (True) or one factor is
        shared by all channels (False)
        :param data_key: key of the entry in data_dict that is augmented
        :param p_per_sample: probability with which each individual sample is augmented
        :param p_per_channel: forwarded unchanged to augment_contrast (presumably the per-channel probability of
        applying the augmentation - confirm against augment_contrast)
        """
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.per_channel = per_channel
        self.data_key = data_key
        self.p_per_sample = p_per_sample
        self.p_per_channel = p_per_channel

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        for sample_idx in range(len(batch)):
            # one random draw per sample decides whether this sample is augmented
            if np.random.uniform() < self.p_per_sample:
                batch[sample_idx] = augment_contrast(batch[sample_idx],
                                                     contrast_range=self.contrast_range,
                                                     preserve_range=self.preserve_range,
                                                     per_channel=self.per_channel,
                                                     p_per_channel=self.p_per_channel)
        return data_dict
63 |
64 |
class NormalizeTransform(AbstractTransform):
    """
    Channel-wise normalization: subtracts means[c] from channel c and divides the result by stds[c].
    The data stored under data_key is modified in place; channels are taken from axis 1.
    """

    def __init__(self, means, stds, data_key='data'):
        self.means = means
        self.stds = stds
        self.data_key = data_key

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        num_channels = batch.shape[1]
        for channel in range(num_channels):
            batch[:, channel] -= self.means[channel]
            batch[:, channel] /= self.stds[channel]
        return data_dict
76 |
77 |
class BrightnessTransform(AbstractTransform):
    def __init__(self, mu, sigma, per_channel=True, data_key="data", p_per_sample=1, p_per_channel=1):
        """
        Additive brightness augmentation. The value that is added is drawn from a Gaussian distribution.

        :param mu: mean of the Gaussian the additive brightness is sampled from
        :param sigma: standard deviation of that Gaussian
        :param per_channel: whether a separate brightness value is used for each channel (True) or one value is
        shared by all channels (False)
        :param data_key: key of the entry in data_dict that is augmented
        :param p_per_sample: probability with which each individual sample is augmented
        :param p_per_channel: forwarded unchanged to augment_brightness_additive
        """
        self.mu = mu
        self.sigma = sigma
        self.per_channel = per_channel
        self.data_key = data_key
        self.p_per_sample = p_per_sample
        self.p_per_channel = p_per_channel

    def __call__(self, **data_dict):
        key = self.data_key
        batch = data_dict[key]

        for sample_idx in range(batch.shape[0]):
            # one random draw per sample decides whether this sample is augmented
            if np.random.uniform() < self.p_per_sample:
                batch[sample_idx] = augment_brightness_additive(batch[sample_idx], self.mu, self.sigma,
                                                                self.per_channel,
                                                                p_per_channel=self.p_per_channel)

        data_dict[key] = batch
        return data_dict
106 |
107 |
class BrightnessMultiplicativeTransform(AbstractTransform):
    def __init__(self, multiplier_range=(0.5, 2), per_channel=True, data_key="data", p_per_sample=1):
        """
        Multiplicative brightness augmentation. The multiplier is sampled from multiplier_range.

        :param multiplier_range: range from which the brightness multiplier is sampled uniformly
        :param per_channel: whether a separate multiplier is used for each channel (True) or one multiplier is
        shared by all channels (False)
        :param data_key: key of the entry in data_dict that is augmented
        :param p_per_sample: probability with which each individual sample is augmented
        """
        self.multiplier_range = multiplier_range
        self.per_channel = per_channel
        self.data_key = data_key
        self.p_per_sample = p_per_sample

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        for sample_idx in range(len(batch)):
            # one random draw per sample decides whether this sample is augmented
            if np.random.uniform() < self.p_per_sample:
                batch[sample_idx] = augment_brightness_multiplicative(batch[sample_idx],
                                                                      self.multiplier_range,
                                                                      self.per_channel)
        return data_dict
130 |
131 |
class GammaTransform(AbstractTransform):
    def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, data_key="data",
                 retain_stats: Union[bool, Callable[[], bool]] = False, p_per_sample=1):
        """
        Augments by changing the 'gamma' of the image (same as gamma correction in photos or computer monitors).

        :param gamma_range: tuple of float, range to sample gamma from. If one value is < 1 and the other one is
        > 1, half of the samples will be augmented with gamma < 1 and the other half with gamma > 1 (within the
        given interval).
        :param invert_image: whether to invert the image before applying gamma augmentation
        :param per_channel: forwarded unchanged to augment_gamma
        :param data_key: key of the entry in data_dict that is augmented
        :param retain_stats: gamma transformation alters the mean and std of the data in the patch. With
        retain_stats=True the data is transformed to match the mean and standard deviation it had before gamma
        augmentation. May also be a callable with signature retain_stats() -> bool.
        :param p_per_sample: probability with which each individual sample is augmented
        """
        self.gamma_range = gamma_range
        self.invert_image = invert_image
        self.per_channel = per_channel
        self.data_key = data_key
        self.retain_stats = retain_stats
        self.p_per_sample = p_per_sample

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        for sample_idx in range(len(batch)):
            # one random draw per sample decides whether this sample is augmented
            if np.random.uniform() < self.p_per_sample:
                batch[sample_idx] = augment_gamma(batch[sample_idx], self.gamma_range,
                                                  self.invert_image,
                                                  per_channel=self.per_channel,
                                                  retain_stats=self.retain_stats)
        return data_dict
165 |
166 |
class IlluminationTransform(AbstractTransform):
    """
    Do not use this for now.

    Passes the whole batch stored under data_key through augment_illumination together with white_rgb.
    """

    def __init__(self, white_rgb, data_key="data"):
        self.white_rgb = white_rgb
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = augment_illumination(data_dict[key], self.white_rgb)
        return data_dict
177 |
178 |
class FancyColorTransform(AbstractTransform):
    """
    Do not use this for now.

    Passes the whole batch stored under data_key through augment_PCA_shift together with U, s and sigma.
    """

    def __init__(self, U, s, sigma=0.2, data_key="data"):
        self.U = U
        self.s = s
        self.sigma = sigma
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = augment_PCA_shift(data_dict[key], self.U, self.s, self.sigma)
        return data_dict
191 |
192 |
class ClipValueRange(AbstractTransform):
    def __init__(self, min=None, max=None, data_key="data"):
        """
        Clips the values of the data to the interval [min, max] using np.clip.

        :param min: lower bound handed to np.clip
        :param max: upper bound handed to np.clip
        :param data_key: key of the entry in data_dict that is clipped
        """
        self.min = min
        self.max = max
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        # np.clip returns a new array which replaces the entry in the dict
        data_dict[key] = np.clip(data_dict[key], self.min, self.max)
        return data_dict
208 |
--------------------------------------------------------------------------------
/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import traceback
17 | from copy import deepcopy
18 | from typing import List, Union, Callable
19 | import threading
20 | from builtins import range
21 | from multiprocessing import Process
22 | from multiprocessing import Queue
23 | from queue import Queue as thrQueue
24 | import numpy as np
25 | import logging
26 | from multiprocessing import Event
27 | from time import sleep, time
28 |
29 | from batchgenerators.dataloading.data_loader import DataLoader
30 | from threadpoolctl import threadpool_limits
31 |
32 | try:
33 | import torch
34 | except ImportError:
35 | torch = None
36 |
37 |
def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
             abort_event: Event, wait_time: float = 0.02):
    """
    Body of a background worker: repeatedly pulls a batch from data_loader, optionally applies transform and puts
    the result into queue until abort_event is set.

    :param queue: queue the finished batches are put into
    :param data_loader: object that is iterated with next() and offers set_thread_id
    :param transform: callable invoked as transform(**batch), or None to skip transformation
    :param thread_id: id of this worker, handed to data_loader.set_thread_id and used in error output
    :param seed: passed to np.random.seed at startup
    :param abort_event: checked to stop the worker; the worker sets it itself if something goes wrong
    :param wait_time: seconds to sleep while the queue is full
    """
    with threadpool_limits(1, None):
        np.random.seed(seed)
        data_loader.set_thread_id(thread_id)
        pending = None  # batch that has been produced but could not be enqueued yet

        try:
            while not abort_event.is_set():
                if pending is None:
                    pending = next(data_loader)
                    if transform is not None:
                        pending = transform(**pending)

                # check again after producing the batch, before touching the queue
                if abort_event.is_set():
                    return

                if queue.full():
                    # keep the batch and try again after a short pause
                    sleep(wait_time)
                else:
                    queue.put(pending)
                    pending = None

        except KeyboardInterrupt:
            abort_event.set()

        except Exception as exc:
            print("Exception in background worker %d:\n" % thread_id, exc)
            traceback.print_exc()
            abort_event.set()
75 |
76 |
def pin_memory_of_all_eligible_items_in_dict(result_dict):
    """
    Replaces every torch.Tensor value of result_dict by the result of its pin_memory() call. All other values are
    left untouched. The dict is modified in place and also returned.
    """
    for key, value in result_dict.items():
        if isinstance(value, torch.Tensor):
            # only values are replaced, the set of keys stays the same while iterating
            result_dict[key] = value.pin_memory()
    return result_dict
82 |
83 |
def results_loop(in_queue: Queue, out_queue: thrQueue, abort_event: Event,
                 pin_memory: bool, worker_list: List[Process],
                 gpu: Union[int, None] = None, wait_time: float = 0.02):
    """
    Moves items from in_queue to out_queue until abort_event is set. Intended to be run in a thread.

    :param in_queue: queue the items are taken from
    :param out_queue: queue the items are put into
    :param abort_event: the loop returns once this is set; it is set by this function on any error
    :param pin_memory: request memory pinning of torch tensors in the items. Only done if torch is importable,
    gpu is not None and CUDA is available.
    :param worker_list: workers whose is_alive() is checked on every iteration
    :param gpu: device passed to torch.cuda.set_device when pinning is active
    :param wait_time: seconds to sleep when in_queue is empty or out_queue is full
    :raises RuntimeError: if at least one worker is no longer alive
    """
    do_pin_memory = torch is not None and pin_memory and gpu is not None and torch.cuda.is_available()

    if do_pin_memory:
        print('using pin_memory on device', gpu)
        torch.cuda.set_device(gpu)

    pending = None  # item taken from in_queue that still has to be put into out_queue

    try:
        while not abort_event.is_set():
            # check if all workers are still alive
            workers_alive = [worker.is_alive() for worker in worker_list]
            if not all(workers_alive):
                abort_event.set()
                raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
                                   "print statements above for the actual error message")

            if pending is None:
                if in_queue.empty():
                    sleep(wait_time)
                    continue
                pending = in_queue.get()
                if do_pin_memory:
                    pending = pin_memory_of_all_eligible_items_in_dict(pending)

            # from here on there is an item to hand over
            if out_queue.full():
                sleep(wait_time)
                continue
            out_queue.put(pending)
            pending = None

    except Exception as err:
        # make sure everybody else stops as well before propagating the error
        abort_event.set()
        raise err
126 |
127 |
class NonDetMultiThreadedAugmenter(object):
    """
    Non-deterministic but potentially faster than MultiThreadedAugmenter and uses less RAM. Also less complicated.
    This one only has one queue through which the communication with background workers happens, meaning that there
    can be a race condition to it (and thus a nondeterministic ordering of batches). The advantage of this approach is
    that we will never run into the issue where everything needs to wait for worker X to finish its work.
    Also this approach requires less RAM because we do not need to have some number of cached batches per worker and
    now use a global pool of cached batches that is shared among all workers.
    THIS MTA ONLY WORKS WITH DATALOADER THAT RETURN INFINITE RANDOM SAMPLES! So if you are using DataLoader, make sure
    to set infinite=True.
    Seeding this is not recommended :-)
    """

    def __init__(self, data_loader, transform, num_processes, num_cached=2, seeds=None, pin_memory=False,
                 wait_time=0.02, results_loop_fn: Callable = results_loop):
        """
        :param data_loader: source of batches; DataLoader instances must have infinite=True
        :param transform: transform applied by the workers to each batch (None for no transform)
        :param num_processes: number of background worker processes
        :param num_cached: maximum size of both the worker queue and the results queue
        :param seeds: None, or one seed per worker process
        :param pin_memory: forwarded to results_loop_fn
        :param wait_time: polling interval in seconds used by all loops
        :param results_loop_fn: function run in a background thread that moves batches from the worker queue to
        the results queue. Must accept the same arguments as results_loop.
        """
        # set before anything can fail so that _finish() (called from __del__) always finds this attribute
        self.initialized = False

        self.pin_memory = pin_memory
        self.transform = transform
        self.num_cached = num_cached

        if isinstance(data_loader, DataLoader):
            assert data_loader.infinite, "Only use DataLoader instances that have infinite=True"
        self.generator = data_loader
        self.num_processes = num_processes

        self._queue = None
        self._processes = []
        # use the function handed in by the caller (previously the argument was silently ignored)
        self.results_loop_fn = results_loop_fn
        self.results_loop_thread = None
        self.results_loop_queue = None
        self.abort_event = None

        self.wait_time = wait_time

        if seeds is not None:
            assert len(seeds) == num_processes
        else:
            seeds = [None] * num_processes
        self.seeds = seeds

    def __iter__(self):
        return self

    def next(self):
        return self.__next__()

    def __get_next_item(self):
        """
        Polls the results queue until an item is available and returns it.

        :raises RuntimeError: if the abort event was set (workers and results loop are shut down first)
        """
        item = None

        while item is None:
            if self.abort_event.is_set():
                # the results loop thread checks for dead workers and sets the abort event if necessary
                self._finish()
                raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
                                   "print statements above for the actual error message")

            if not self.results_loop_queue.empty():
                item = self.results_loop_queue.get()
                self.results_loop_queue.task_done()
            else:
                sleep(self.wait_time)

        return item

    def __next__(self):
        # workers are started lazily on first use
        if not self.initialized:
            self._start()

        return self.__get_next_item()

    def _start(self):
        """Creates queues, abort event, worker processes and the results loop thread. No-op if already running."""
        if self.initialized:
            logging.debug("MultiThreadedGenerator Warning: start() has been called but workers are already running")
            return

        self._queue = Queue(self.num_cached)
        self.results_loop_queue = thrQueue(self.num_cached)
        self.abort_event = Event()

        logging.debug("starting workers")
        if isinstance(self.generator, DataLoader):
            self.generator.was_initialized = False

        # temporarily restrict torch and the numeric thread pools to one thread while the workers are started,
        # the torch setting is restored afterwards
        if torch is not None:
            torch_nthreads = torch.get_num_threads()
            torch.set_num_threads(1)
        with threadpool_limits(limits=1, user_api=None):
            for i in range(self.num_processes):
                self._processes.append(Process(target=producer, args=(
                    self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event, self.wait_time
                )))
                self._processes[-1].daemon = True
            for process in self._processes:
                process.start()
        if torch is not None:
            torch.set_num_threads(torch_nthreads)

        if torch is not None and torch.cuda.is_available():
            gpu = torch.cuda.current_device()
        else:
            gpu = None

        # argument order expected by the results loop:
        # in_queue, out_queue, abort_event, pin_memory, worker_list, gpu, wait_time
        self.results_loop_thread = threading.Thread(target=self.results_loop_fn, args=(
            self._queue, self.results_loop_queue, self.abort_event, self.pin_memory, self._processes, gpu,
            self.wait_time)
        )
        self.results_loop_thread.daemon = True
        self.results_loop_thread.start()

        self.initialized = True

    def _finish(self):
        """Signals abort, terminates workers that are still alive and resets all state. No-op if not running."""
        if self.initialized:
            self.abort_event.set()
            # give workers and the results loop a moment to notice the abort event
            sleep(self.wait_time)
            for process in self._processes:
                if process.is_alive():
                    process.terminate()

        self._queue, self.results_loop_queue, self.results_loop_thread, self.abort_event = None, None, None, None
        self._processes = []
        self.initialized = False

    def restart(self):
        self._finish()
        self._start()

    def __del__(self):
        logging.debug("MultiThreadedGenerator: destructor was called")
        self._finish()
261 |
262 |
if __name__ == '__main__':
    # Manual smoke test / timing: draw 1000 batches from a dummy loader through three workers
    from tests.test_DataLoader import DummyDataLoader

    indices = deepcopy(list(range(1234)))
    dl = DummyDataLoader(indices, 2, 3, None,
                         return_incomplete=False, shuffle=True,
                         infinite=True)

    mt = NonDetMultiThreadedAugmenter(dl, None, 3, 2, None, False, 0.02)
    mt._start()

    st = time()
    for i in range(1000):
        print(i)
        b = next(mt)
    end = time()
    # elapsed wall clock time in seconds
    print(end - st)

    mt._finish()
--------------------------------------------------------------------------------
/tests/test_DataLoader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | from copy import deepcopy
18 | import numpy as np
19 | from batchgenerators.dataloading.data_loader import DataLoader
20 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
21 |
22 |
class DummyDataLoader(DataLoader):
    """
    Minimal DataLoader for the tests: a 'batch' is simply whatever get_indices() returns, so the tests can check
    which indices were handed out.
    """

    def __init__(self, data, batch_size, num_threads_in_multithreaded, seed_for_shuffle=1, return_incomplete=False,
                 shuffle=True, infinite=False):
        super().__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle,
                         return_incomplete, shuffle, infinite)
        # the data itself serves as the list of indices
        self.indices = data

    def generate_train_batch(self):
        return self.get_indices()
33 |
34 |
class TestDataLoader(unittest.TestCase):
    """
    Tests for DataLoader iteration (via DummyDataLoader), both used directly and wrapped in a
    MultiThreadedAugmenter: completeness of returned indices, shuffling, infinite mode and return_incomplete.
    """

    def test_return_all_indices_single_threaded_shuffle_False(self):
        data = list(range(123))
        batch_sizes = [1, 3, 75, 12, 23]

        for b in batch_sizes:
            dl = DummyDataLoader(deepcopy(data), b, 1, 1, return_incomplete=True, shuffle=False, infinite=False)

            # iterate the same loader several times, every pass must return everything in order
            for _ in range(3):
                idx = []
                for i in dl:
                    idx += i

                self.assertEqual(len(idx), len(data))
                self.assertTrue(all(i == j for i, j in zip(idx, data)))

    def test_return_all_indices_single_threaded_shuffle_True(self):
        data = list(range(123))
        batch_sizes = [1, 3, 75, 12, 23]
        np.random.seed(1234)

        for b in batch_sizes:
            dl = DummyDataLoader(deepcopy(data), b, 1, 1, return_incomplete=True, shuffle=True, infinite=False)

            for _ in range(3):
                idx = []
                for i in dl:
                    idx += i

                self.assertEqual(len(idx), len(data))

                # shuffled order must differ from the original order ...
                self.assertFalse(all(i == j for i, j in zip(idx, data)))

                # ... but contain exactly the same elements
                idx.sort()
                self.assertTrue(all(i == j for i, j in zip(idx, data)))

    def test_infinite_single_threaded(self):
        data = list(range(123))

        dl = DummyDataLoader(deepcopy(data), 12, 1, 1, return_incomplete=True, shuffle=True, infinite=False)
        # this should raise a StopIteration
        with self.assertRaises(StopIteration):
            for _ in range(1000):
                next(dl)

        dl = DummyDataLoader(deepcopy(data), 12, 1, 1, return_incomplete=True, shuffle=True, infinite=True)
        # this should now not raise a StopIteration anymore
        for _ in range(1000):
            next(dl)

    def test_return_incomplete_single_threaded(self):
        data = list(range(123))
        batch_size = 12

        dl = DummyDataLoader(deepcopy(data), batch_size, 1, 1, return_incomplete=False, shuffle=False, infinite=False)
        # without return_incomplete only full batches are returned, the remainder is dropped
        total = 0
        ctr = 0
        for i in dl:
            ctr += 1
            self.assertEqual(len(i), batch_size)
            total += batch_size

        self.assertEqual(total, 120)
        self.assertEqual(ctr, 10)

        dl = DummyDataLoader(deepcopy(data), batch_size, 1, 1, return_incomplete=True, shuffle=False, infinite=False)
        # with return_incomplete the last, smaller batch is returned as well
        total = 0
        ctr = 0
        for i in dl:
            ctr += 1
            total += len(i)

        self.assertEqual(total, 123)
        self.assertEqual(ctr, 11)

    def test_return_all_indices_multi_threaded_shuffle_False(self):
        data = list(range(123))
        batch_sizes = [1, 3, 75, 12, 23]
        num_workers = 3

        for b in batch_sizes:
            dl = DummyDataLoader(deepcopy(data), b, num_workers, 1, return_incomplete=True, shuffle=False,
                                 infinite=False)
            mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)

            for _ in range(3):
                idx = []
                for i in mt:
                    idx += i

                self.assertEqual(len(idx), len(data))
                self.assertTrue(all(i == j for i, j in zip(idx, data)))

    def test_return_all_indices_multi_threaded_shuffle_True(self):
        data = list(range(123))
        batch_sizes = [1, 3, 75, 12, 23]
        num_workers = 3

        for b in batch_sizes:
            dl = DummyDataLoader(deepcopy(data), b, num_workers, 1, return_incomplete=True, shuffle=True,
                                 infinite=False)
            mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)

            for _ in range(3):
                idx = []
                for i in mt:
                    idx += i

                self.assertEqual(len(idx), len(data))

                # shuffled order must differ from the original order ...
                self.assertFalse(all(i == j for i, j in zip(idx, data)))

                # ... but contain exactly the same elements
                idx.sort()
                self.assertTrue(all(i == j for i, j in zip(idx, data)))

    def test_infinite_multi_threaded(self):
        data = list(range(123))
        num_workers = 3

        dl = DummyDataLoader(deepcopy(data), 12, num_workers, 1, return_incomplete=True, shuffle=True,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)

        # this should raise a StopIteration
        with self.assertRaises(StopIteration):
            for _ in range(1000):
                next(mt)

        dl = DummyDataLoader(deepcopy(data), 12, num_workers, 1, return_incomplete=True, shuffle=True,
                             infinite=True)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        # this should now not raise a StopIteration anymore
        for _ in range(1000):
            next(mt)

    def test_return_incomplete_multi_threaded(self):
        data = list(range(123))
        batch_size = 12
        num_workers = 3

        dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1, return_incomplete=False, shuffle=False,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        # without return_incomplete only full batches are returned, the remainder is dropped
        all_return = []
        total = 0
        ctr = 0
        for i in mt:
            ctr += 1
            self.assertEqual(len(i), batch_size)
            total += len(i)
            all_return += i

        self.assertEqual(total, 120)
        self.assertEqual(ctr, 10)
        # no index may be returned twice
        self.assertEqual(len(np.unique(all_return)), total)

        dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1, return_incomplete=True, shuffle=False,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        # with return_incomplete the last, smaller batch is returned as well
        all_return = []
        total = 0
        ctr = 0
        for i in mt:
            ctr += 1
            total += len(i)
            all_return += i

        self.assertEqual(total, 123)
        self.assertEqual(ctr, 11)
        self.assertEqual(len(np.unique(all_return)), len(data))

    def test_thoroughly(self):
        # local import: itertools is not imported at module level
        from itertools import product

        data_list = [list(range(123)),
                     list(range(1243)),
                     list(range(1)),
                     list(range(7)),
                     ]
        worker_list = (1, 4, 7)
        batch_size_list = (1, 3, 32)
        seed_list = [318, None]
        epochs = 3

        # product iterates with the rightmost argument changing fastest, i.e. like nested for loops in this order
        configurations = product(data_list, worker_list, batch_size_list, [True, False], [True, False], seed_list)

        for data, num_workers, batch_size, return_incomplete, shuffle, seed_for_shuffle in configurations:
            full_batches, remainder = divmod(len(data), batch_size)
            if return_incomplete and remainder:
                # the incomplete last batch counts as well
                expected_num_batches = full_batches + 1
            else:
                expected_num_batches = full_batches

            expected_num_items = len(data) if return_incomplete else expected_num_batches * batch_size

            print("init")
            dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, seed_for_shuffle,
                                 return_incomplete=return_incomplete, shuffle=shuffle,
                                 infinite=False)

            mt = MultiThreadedAugmenter(dl, None, num_workers, 5, None, False, wait_time=0)
            mt._start()

            for _ in range(epochs):
                print("sampling")
                all_return = []
                total = 0
                ctr = 0
                for i in mt:
                    ctr += 1
                    total += len(i)
                    all_return += i

                print('asserting')
                self.assertEqual(total, expected_num_items)
                self.assertEqual(ctr, expected_num_batches)
                self.assertEqual(len(np.unique(all_return)), expected_num_items)
257 |
258 |
if __name__ == "__main__":
    import multiprocessing

    # must be called before worker processes are spawned when running as a frozen executable
    multiprocessing.freeze_support()
    unittest.main()
263 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2021] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
190 | AND Applied Computer Vision Lab, Helmholtz Imaging Platform]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/batchgenerators/dataloading/data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from abc import ABCMeta, abstractmethod
17 | from builtins import object
18 | import warnings
19 | from collections import OrderedDict
20 | from warnings import warn
21 | import numpy as np
22 |
23 | from batchgenerators.dataloading.dataset import Dataset
24 |
25 |
class DataLoaderBase(object):
    """Deprecated base class for batch generators. Use SlimDataLoaderBase for new code.

    Derive from this class and override generate_train_batch. If you don't want to use this you can use any
    generator.
    You can modify this class however you want. How the data is presented as a batch is your responsibility. You can
    sample randomly, cycle through the training examples or sample the data according to a specific pattern. Just
    make sure to use our default data structure!
    {'data':your_batch_of_shape_(b, c, x, y(, z)),
    'seg':your_batch_of_shape_(b, c, x, y(, z)),
    'anything_else1':whatever,
    'anything_else2':whatever2,
    ...}

    (seg is optional)

    Args:
        data (anything): Your dataset. Stored as member variable self._data

        BATCH_SIZE (int): batch size. Stored as member variable self.BATCH_SIZE

        num_batches (int): How many batches will be generated before raising StopIteration. None=unlimited. Careful
        when using MultiThreadedAugmenter: Each process will produce num_batches batches.

        seed (False, None, int): seed to seed the numpy rng with. False = no seeding

    """
    def __init__(self, data, BATCH_SIZE, num_batches=None, seed=False):
        warnings.simplefilter("once", DeprecationWarning)
        # stacklevel=2 attributes the warning to the code that constructs the loader rather than to this module
        warn("This DataLoader will soon be removed. Migrate everything to SlimDataLoaderBase now!",
             DeprecationWarning, stacklevel=2)
        self._data = data
        self.BATCH_SIZE = BATCH_SIZE
        if num_batches is not None:
            warn("We currently strongly discourage using num_batches != None! That does not seem to work properly",
                 stacklevel=2)
        self._num_batches = num_batches
        self._seed = seed
        self._was_initialized = False
        if self._num_batches is None:
            # "unlimited": a count that self._batches_generated will never reach in practice
            self._num_batches = 1e100
        self._batches_generated = 0
        self.thread_id = 0

    def reset(self):
        """Re-seed the global numpy RNG (unless seed is False) and restart the batch counter."""
        if self._seed is not False:
            np.random.seed(self._seed)
        self._was_initialized = True
        self._batches_generated = 0

    def set_thread_id(self, thread_id):
        """Store the id of the worker this loader instance runs in (self.thread_id)."""
        self.thread_id = thread_id

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise StopIteration once num_batches batches have been produced."""
        if not self._was_initialized:
            self.reset()
        if self._batches_generated >= self._num_batches:
            # mark as uninitialized so that the next call to __next__ starts over via reset()
            self._was_initialized = False
            raise StopIteration
        minibatch = self.generate_train_batch()
        self._batches_generated += 1
        return minibatch

    @abstractmethod
    def generate_train_batch(self):
        '''override this
        Generate your batch from self._data. Make sure you generate the correct batch size (self.BATCH_SIZE)
        '''
        pass
95 |
96 |
class SlimDataLoaderBase(object):
    def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None):
        """
        Slim version of DataLoaderBase (which is now deprecated). Only provides very simple functionality.

        You must derive from this class to implement your own DataLoader. You must override self.generate_train_batch()

        If you use our MultiThreadedAugmenter you will need to also set and use number_of_threads_in_multithreaded. See
        multithreaded_dataloading in examples!

        :param data: will be stored in self._data. You can use it to generate your batches in self.generate_train_batch()
        :param batch_size: will be stored in self.batch_size for use in self.generate_train_batch()
        :param number_of_threads_in_multithreaded: will be stored in self.number_of_threads_in_multithreaded.
        None per default. If you wish to iterate over all your training data only once per epoch, you must coordinate
        your Dataloaders and you will need this information
        """
        self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded
        self._data = data
        self.batch_size = batch_size
        self.thread_id = 0

    def set_thread_id(self, thread_id):
        """Store the id of the worker this loader instance runs in (self.thread_id)."""
        self.thread_id = thread_id

    def __iter__(self):
        return self

    def __next__(self):
        """Return whatever generate_train_batch() produces; iteration only ends if that method raises."""
        return self.generate_train_batch()

    @abstractmethod
    def generate_train_batch(self):
        '''override this
        Generate your batch from self._data. Make sure you generate the correct batch size (self.batch_size)
        '''
        pass
134 |
135 |
class DataLoader(SlimDataLoaderBase):
    def __init__(self, data, batch_size, num_threads_in_multithreaded=1, seed_for_shuffle=None, return_incomplete=False,
                 shuffle=True, infinite=False, sampling_probabilities=None):
        """
        Index-serving data loader. Subclasses set self.indices and call self.get_indices() from
        generate_train_batch to learn which samples belong to the next batch.

        :param data: will be stored in self._data for use in generate_train_batch
        :param batch_size: will be used by get_indices to return the correct number of indices
        :param num_threads_in_multithreaded: num_threads_in_multithreaded necessary for synchronization of dataloaders
        when using multithreaded augmenter
        :param seed_for_shuffle: for reproducibility
        :param return_incomplete: whether or not to return batches that are incomplete. Only applies if infinite=False.
        If your data has len of 34 and your batch size is 32 then return_incomplete=False will make this loader
        return only one batch of shape 32 (omitting 2 of your training examples). If return_incomplete=True a second
        batch with batch size 2 will be returned.
        :param shuffle: if True, the order of the indices will be shuffled between epochs. Only applies if infinite=False
        :param infinite: if True, each batch contains randomly (uniformly) sampled indices. An unlimited number of
        batches is returned. If False, DataLoader will iterate over the data only once
        :param sampling_probabilities: only applies if infinite=True. If sampling_probabilities is not None, the
        probabilities will be used by np.random.choice to sample the indexes for each batch. Important:
        sampling_probabilities must have as many entries as there are samples in your dataset AND
        sampling_probabilities needs to sum to 1
        """
        super(DataLoader, self).__init__(data, batch_size, num_threads_in_multithreaded)
        self.infinite = infinite
        self.shuffle = shuffle
        self.return_incomplete = return_incomplete
        self.sampling_probabilities = sampling_probabilities

        # private random state, only used for shuffling self.indices in reset()
        self.seed_for_shuffle = seed_for_shuffle
        self.rs = np.random.RandomState(self.seed_for_shuffle)

        # bookkeeping for one pass over self.indices
        self.current_position = None
        self.was_initialized = False
        self.last_reached = False

        # subclasses must assign this; it cannot be done here because the layout of data is unknown
        self.indices = None

    def reset(self):
        """Start a new pass: place the cursor at this thread's first batch and optionally reshuffle."""
        assert self.indices is not None

        # each thread starts one batch further in; get_indices then strides over the other threads' batches
        self.current_position = self.thread_id * self.batch_size

        self.was_initialized = True

        # shuffling is pointless when batches are drawn at random anyway
        if self.shuffle and not self.infinite:
            self.rs.shuffle(self.indices)

        self.last_reached = False

    def get_indices(self):
        """
        Return the indices of the next batch.

        With infinite=True the indices are drawn with replacement via np.random.choice (global numpy RNG, not
        self.rs). Otherwise consecutive entries of self.indices are returned and StopIteration is raised (after
        calling reset()) once the pass is finished.
        """
        if self.infinite:
            return np.random.choice(self.indices, self.batch_size, replace=True, p=self.sampling_probabilities)

        # the previous call handed out the final (incomplete) batch of this pass
        if self.last_reached:
            self.reset()
            raise StopIteration

        if not self.was_initialized:
            self.reset()

        start = self.current_position
        stop = min(start + self.batch_size, len(self.indices))
        selected = [self.indices[pos] for pos in range(start, stop)]
        self.current_position = start + len(selected)

        # fewer entries than requested means the end of self.indices was hit
        if len(selected) < self.batch_size:
            self.last_reached = True

        if not selected or (self.last_reached and not self.return_incomplete):
            self.reset()
            raise StopIteration

        # skip the batches that belong to the other threads
        self.current_position += (self.number_of_threads_in_multithreaded - 1) * self.batch_size
        return selected

    @abstractmethod
    def generate_train_batch(self):
        '''
        make use of self.get_indices() to know what indices to work on!
        :return:
        '''
        pass
222 |
223 |
def default_collate(batch):
    '''
    Merge a list of samples into one batch. Heavily inspired by the default_collate function of pytorch.

    The type of the first sample (batch[0]) decides how the whole batch is merged:
    - np.ndarray: samples are joined with np.vstack (concatenated along the first axis, no new axis is added)
    - np.float64 scalars: float64 array
    - int / numpy integer scalars: int32 array
    - float / other numpy floating scalars: float32 array
    - dict: dict with the keys of the first sample, each value collated recursively
    - tuple / list: list with one recursively collated entry per position
    - str: batch is returned unchanged

    :param batch: non-empty sequence of samples that all have the same type
    :return: the collated batch, see above
    :raises TypeError: if the type of batch[0] is not supported
    '''
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.vstack(batch)
    elif isinstance(first, np.float64):
        # must be tested before the generic float branch: np.float64 is a subclass of Python's float and would
        # otherwise be silently downcast to float32
        return np.array(batch).astype(np.float64)
    elif isinstance(first, (int, np.integer)):
        return np.array(batch).astype(np.int32)
    elif isinstance(first, (float, np.floating)):
        return np.array(batch).astype(np.float32)
    elif isinstance(first, dict):
        # covers OrderedDict as well (it is a dict subclass)
        return {key: default_collate([d[key] for d in batch]) for key in first}
    elif isinstance(first, (tuple, list)):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    elif isinstance(first, str):
        return batch
    else:
        # report the type of the offending sample, not the type of the surrounding container
        raise TypeError('unknown type for batch:', type(first))
247 |
248 |
class DataLoaderFromDataset(DataLoader):
    def __init__(self, data, batch_size, num_threads_in_multithreaded, seed_for_shuffle=1, collate_fn=default_collate,
                 return_incomplete=False, shuffle=True, infinite=False, sampling_probabilities=None):
        '''
        A simple dataloader that can take a Dataset as data.
        It is not super efficient because I cannot make too many hard assumptions about what data_dict will contain.
        If you know what you need, implement your own!
        :param data: must be a Dataset instance (asserted). Its items are fetched via data[i]
        :param batch_size: number of samples per batch
        :param num_threads_in_multithreaded: forwarded to DataLoader
        :param seed_for_shuffle: forwarded to DataLoader
        :param collate_fn: callable that merges the list of samples of one batch into the returned batch
        :param return_incomplete: forwarded to DataLoader
        :param shuffle: forwarded to DataLoader
        :param infinite: forwarded to DataLoader
        :param sampling_probabilities: forwarded to DataLoader. None (the default) keeps the previous behavior
        '''
        super(DataLoaderFromDataset, self).__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle,
                                                    return_incomplete=return_incomplete, shuffle=shuffle,
                                                    infinite=infinite,
                                                    sampling_probabilities=sampling_probabilities)
        self.collate_fn = collate_fn
        assert isinstance(self._data, Dataset)
        # one index per sample of the dataset
        self.indices = np.arange(len(data))

    def generate_train_batch(self):
        '''
        Fetch the samples selected by self.get_indices() from the dataset and merge them with self.collate_fn.
        :return: whatever self.collate_fn returns for the list of samples
        '''
        indices = self.get_indices()

        batch = [self._data[i] for i in indices]

        return self.collate_fn(batch)
274 |
--------------------------------------------------------------------------------