├── tests ├── __init__.py └── transform │ ├── __init__.py │ ├── utils │ ├── __init__.py │ └── boolean_tree_test.py │ └── sequences │ ├── __init__.py │ ├── functional_test.py │ └── sequence_transformers_test.py ├── examples ├── __init__.py ├── example.gif └── make_gifs.py ├── transform ├── __init__.py ├── utils │ ├── __init__.py │ ├── utils.py │ ├── boolean_tree.py │ └── transformations.py └── sequences │ ├── __init__.py │ ├── functional.py │ └── sequence_transformers.py ├── pytest.ini ├── setup.py ├── LICENSE ├── .gitignore ├── .circleci └── config.yml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transform/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/transform/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/transform/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/transform/sequences/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transform/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .boolean_tree import * 2 | from .utils import * 3 | -------------------------------------------------------------------------------- /examples/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dref360/keras-transform/HEAD/examples/example.gif -------------------------------------------------------------------------------- /transform/sequences/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence_transformers import * 2 | from .functional import * 3 | -------------------------------------------------------------------------------- /transform/utils/utils.py: -------------------------------------------------------------------------------- 1 | def get_batch_size(batch): 2 | """Get the batch size from a tree structure.""" 3 | if isinstance(batch, (list, tuple)): 4 | return get_batch_size(batch[0]) 5 | return batch.shape[0] 6 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # Configuration of py.test 2 | [pytest] 3 | addopts=-v 4 | --durations=10 5 | 6 | # Do not run tests in the build folder 7 | norecursedirs= build 8 | 9 | # PEP-8 The following are ignored: 10 | # E501 line too long (82 > 79 characters) 11 | # E402 module level import not at top of file - temporary measure to continue adding ros python packaged in sys.path 12 | # E731 do not assign a lambda expression, use a def 13 | 14 | pep8ignore=* E501 \ 15 | * E402 \ 16 | * E731 \ 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | setup(name='keras-transform', 5 | version='0.1.1', 6 | description='Library for data augmentation', 7 | author='Frederic Branchaud-Charron', 8 | author_email='frederic.branchaud-charron@usherbrooke.ca', 9 | url='https://github.com/Dref360/keras-transform', 10 | license='MIT', 11 | install_requires=['numpy', 'theano', 'keras>=2.2.0'], 12 | extras_require={ 13 | 'tests': ['pytest', 14 | 'pytest-pep8', 15 | 'pytest-xdist', 16 | 'pytest-cov'], 17 | }, 18 | packages=find_packages()) 19 | -------------------------------------------------------------------------------- /transform/sequences/functional.py: -------------------------------------------------------------------------------- 1 | class SequentialTransformer(): 2 | def __init__(self, transformers): 3 | """ 4 | Combine multiple transformers. 5 | :param transformers: List of SequenceTransformers 6 | """ 7 | self.transformers = transformers 8 | 9 | def __call__(self, seq, mask=(True, False)): 10 | """ 11 | Create a transformer that combines multiples transformers. 12 | :param seq: Sequence object 13 | :param mask: Boolean tree-like structure. 14 | :return: Sequence 15 | """ 16 | for transformer in self.transformers: 17 | seq = transformer(seq, mask) 18 | return seq 19 | -------------------------------------------------------------------------------- /transform/utils/boolean_tree.py: -------------------------------------------------------------------------------- 1 | def get_value(tree, idx): 2 | if not idx or not isinstance(tree, (list, tuple)): 3 | return tree 4 | elif len(idx) == 1: 5 | return tree[idx[0]] 6 | else: 7 | return get_value(tree[idx[0]], idx[1:]) 8 | 9 | 10 | def handle_mask(mask, tree): 11 | """Expand the mask to match the tree structure. 12 | :param mask: boolean mask 13 | :param tree: tree structure 14 | :return: boolean mask 15 | """ 16 | if isinstance(mask, bool): 17 | return [mask] * len(tree) 18 | return mask 19 | 20 | 21 | def apply_fun(tree, fun, mask, **kwargs): 22 | """Apply a function recursively on a list. 23 | :param tree: Tree structure of lists 24 | :param fun: function to apply 25 | :param mask: boolean mask to control the application of `fun`. 26 | :param kwargs: arguments for `fun` 27 | :return: list 28 | """ 29 | if not isinstance(tree, (list, tuple)): 30 | return fun(tree, **kwargs) if mask else tree 31 | else: 32 | return [apply_fun(tr, fun, ma, **kwargs) if ma else tr for tr, ma in zip(tree, handle_mask(mask, tree))] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Frédéric Branchaud-Charron 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | # specify the version you desire here 10 | # use `-browsers` prefix for selenium tests, e.g. `3.6.1-browsers` 11 | - image: circleci/python:3.6.1 12 | 13 | # Specify service dependencies here if necessary 14 | # CircleCI maintains a library of pre-built images 15 | # documented at https://circleci.com/docs/2.0/circleci-images/ 16 | # - image: circleci/postgres:9.4 17 | 18 | working_directory: ~/keras-transform 19 | 20 | steps: 21 | - checkout 22 | 23 | # Download and cache dependencies 24 | - restore_cache: 25 | keys: 26 | - v1-dependencies-{{ checksum "setup.py" }} 27 | # fallback to using the latest cache if no exact match is found 28 | - v1-dependencies- 29 | 30 | - run: 31 | name: install dependencies 32 | command: | 33 | python3 -m venv venv 34 | . venv/bin/activate 35 | pip3 install numpy theano tensorflow 36 | pip3 install keras_preprocessing keras>=2.2.0 --no-deps 37 | pip3 install -e .[tests] 38 | 39 | - save_cache: 40 | paths: 41 | - ./venv 42 | key: v1-dependencies-{{ checksum "setup.py" }} 43 | 44 | # run tests! 45 | - run: 46 | name: run tests 47 | command: | 48 | . venv/bin/activate 49 | pytest tests/ 50 | pytest --pep8 -m pep8 -n0 51 | 52 | - store_artifacts: 53 | path: test-reports 54 | destination: test-reports 55 | -------------------------------------------------------------------------------- /examples/make_gifs.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from keras.utils import Sequence 4 | 5 | from transform.sequences import RandomRotationTransformer, RandomHorizontalFlipTransformer, RandomShearTransformer, \ 6 | RandomZoomTransformer 7 | 8 | """First, let's create a simple Sequence that load an image and resize it.""" 9 | 10 | 11 | class SimpleSequence(Sequence): 12 | def __init__(self, paths, shape=(200, 200)): 13 | self.paths = paths 14 | self.shape = shape 15 | self.batch_size = 1 16 | 17 | def __len__(self): 18 | return len(self.paths) // self.batch_size 19 | 20 | def __getitem__(self, index): 21 | paths = self.paths[index * self.batch_size:(index + 1) * self.batch_size] 22 | X = [cv2.resize(cv2.imread(p), self.shape) for p in paths] 23 | y = [cv2.cvtColor(x, cv2.COLOR_BGR2GRAY) for x in X] 24 | return np.array(X), np.array(y).reshape([self.batch_size, self.shape[0], self.shape[1], 1]) 25 | 26 | 27 | """Transformers are Sequence that takes a Sequence to modify it.""" 28 | from glob import glob 29 | 30 | paths = glob('/data/images_folder/*.jpg') 31 | seq = SimpleSequence(paths) 32 | 33 | """Applying the SAME transformation to X and y is done by specifying a mask.""" 34 | transformer = RandomRotationTransformer(10)(seq, mask=[True, True]) 35 | transformer = RandomHorizontalFlipTransformer()(transformer, mask=[True, True]) 36 | transformer = RandomShearTransformer(intensity=0.5)(transformer, mask=[True, True]) 37 | transformer = RandomZoomTransformer(zoom_range=(0.8, 1.2))(transformer, mask=[True, True]) 38 | 39 | # 200,400 40 | vid = cv2.VideoWriter(filename='/data/output.avi', fourcc=cv2.VideoWriter_fourcc(*'MJPG'), fps=5, frameSize=(400, 200), 41 | isColor=True) 42 | 43 | try: 44 | for i in range(100): 45 | X, y = transformer[0] 46 | im = np.concatenate((X[0], cv2.cvtColor(y[0], cv2.COLOR_GRAY2BGR)), 1) 47 | vid.write(im) 48 | cv2.imshow('Test', im) 49 | cv2.waitKey(100) 50 | 51 | except: 52 | pass 53 | vid.release() 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras-transform 2 | Library for data augmentation 3 | 4 | *ANNOUNCEMENT* : I won't really work on this library anymore, the recent changes made with Keras 2.2.0 made this library obselete. Please see my blog post : https://dref360.github.io/deterministic-da/ 5 | 6 | This library provides a data augmentation pipeline for `Sequence` objects. 7 | 8 | **Keras-transform** allows the user to specify a mask to do data augmentation in a flexible way. This is useful in many tasks like segmentation where we want the ground truth to be augmented. 9 | See [simple.ipynb](examples/simple.ipynb). 10 | 11 | **Keras-transform** also works with multiple inputs, outputs by using complex masks. 12 | For example, `mask=[[True,False],False]` would augment the first input but not the second. 13 | 14 | ## keras-transform in 10 lines 15 | 16 | ```python 17 | from transform.sequences import SequentialTransformer 18 | from transform.sequences import RandomZoomTransformer, RandomVerticalFlipTransformer 19 | 20 | seq = ... # A keras.utils.Sequence object that returns a tuple (X,y) 21 | model = ... # A keras Model 22 | 23 | """ 24 | A transformer transforms the input. Most data augmentation functions are implemented in transform.sequences. 25 | We can chain transformers together using the SequentialTransformer that takes a list of transformers. 26 | """ 27 | sequence = SequentialTransformer([RandomZoomTransformer(zoom_range=(0.8,1.2)), 28 | RandomVerticalFlipTransformer()]) 29 | 30 | # To augment X but not y 31 | augmented_sequence = sequence(seq,mask=[True,False]) 32 | model.fit_generator(augmented_sequence,steps_per_epoch=len(augmented_sequence)) 33 | 34 | # To augment X and y 35 | augmented_sequence = sequence(seq,mask=[True,True]) # Alternatively, mask=True would also work. 36 | model.fit_generator(augmented_sequence,steps_per_epoch=len(augmented_sequence)) 37 | 38 | ``` 39 | 40 | 41 | 42 | # Contributing 43 | Anyone can contribute by submitting a PR. 44 | Any PR that adds a new feature needs to be tested. 45 | 46 | # Example 47 | 48 | Here's an example where X is an image and the ground truth is the grayscale version of the input. The code can be found [here](examples/make_gifs.py). 49 | 50 | ![alt-text](/examples/example.gif) 51 | 52 | 53 | -------------------------------------------------------------------------------- /tests/transform/sequences/functional_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from keras.utils import Sequence 4 | 5 | from transform.sequences import SequentialTransformer, RandomZoomTransformer, RandomVerticalFlipTransformer 6 | 7 | 8 | class TestSequence(Sequence): 9 | """Create a X,Y tuple""" 10 | 11 | def __getitem__(self, index): 12 | return np.arange(5 * 20 * 20 * 3).reshape([5, 20, 20, 3]), np.arange(5 * 20 * 20 * 3).reshape([5, 20, 20, 3]) 13 | 14 | def __len__(self): 15 | return 10 16 | 17 | 18 | class TestTreeSequence(Sequence): 19 | """Create a [X1,X2],Y1 tuple.""" 20 | 21 | def __getitem__(self, index): 22 | return [np.arange(5 * 20 * 20 * 3).reshape([5, 20, 20, 3]), 23 | np.arange(5 * 12 * 12 * 3).reshape([5, 12, 12, 3])], np.arange( 24 | 5 * 10 * 10 * 3).reshape([5, 10, 10, 3]) 25 | 26 | def __len__(self): 27 | return 10 28 | 29 | 30 | def inner_transformer(transformer_obj, **kwargs): 31 | transformer = transformer_obj(TestSequence()) 32 | # Assert that X changes between 2 calls and Y does not. 33 | assert np.any(np.not_equal(transformer[0][0], transformer[1][0])) and np.all( 34 | np.equal(transformer[0][1], transformer[1][1])) 35 | 36 | transformer = transformer_obj(TestTreeSequence()) 37 | 38 | assert all([np.any(np.not_equal(t0, t1)) for t0, t1 in zip(transformer[0][0], transformer[1][0])]) and all( 39 | [np.all(np.equal(t0, t1)) for t0, t1 in zip(transformer[0][1], transformer[1][1])]) 40 | 41 | # Test Mask 42 | transformer = transformer_obj(TestTreeSequence(), mask=False) 43 | 44 | assert all([np.any(np.equal(t0, t1)) for t0, t1 in zip(transformer[0][0], transformer[1][0])]) and np.equal( 45 | transformer[0][1], transformer[1][1]).all() 46 | 47 | transformer = transformer_obj(TestTreeSequence(), mask=[True, True]) 48 | 49 | assert all( 50 | [np.any(np.not_equal(t0, t1)) for t0, t1 in zip(transformer[0][0], transformer[1][0])]) and np.not_equal( 51 | transformer[0][1], transformer[1][1]).any() 52 | 53 | # Should transform the same way for X and y 54 | transformer = transformer_obj(TestSequence(), mask=[True, True]) 55 | assert (np.equal(*transformer[0])).all() 56 | 57 | # Common case where we augment X but not y 58 | transformer = transformer_obj(TestSequence(), mask=[True, False]) 59 | assert (np.not_equal(*transformer[0])).any() 60 | 61 | 62 | def test_sequential(): 63 | # TODO need better test. 64 | sequential = SequentialTransformer([RandomZoomTransformer((0.8, 1.2)), 65 | RandomVerticalFlipTransformer()]) 66 | 67 | inner_transformer(sequential) 68 | 69 | 70 | if __name__ == '__main__': 71 | pytest.main([__file__]) 72 | -------------------------------------------------------------------------------- /tests/transform/utils/boolean_tree_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numpy.testing import assert_almost_equal 4 | 5 | from transform.utils import get_value, handle_mask, apply_fun 6 | 7 | 8 | def is_same(arr1, arr2): 9 | """Recursively check if 2 lists of array are equal.""" 10 | if isinstance(arr1, (list, tuple)): 11 | return all([is_same(a1, a2) for a1, a2 in zip(arr1, arr2)]) 12 | return np.allclose(arr1, arr2) 13 | 14 | 15 | def test_get_value(): 16 | # Always true 17 | tree = True 18 | assert get_value(tree, [0]) is True 19 | tree = [True] 20 | assert get_value(tree, [0]) is True 21 | # Common pattern where x is transformed, not y 22 | tree = [True, False] 23 | assert get_value(tree, [0]) is True 24 | assert get_value(tree, [1]) is False 25 | 26 | # Mixed inputs 27 | tree = [[True, False], False] 28 | assert get_value(tree, [0, 0]) is True 29 | assert get_value(tree, [0, 1]) is False 30 | assert get_value(tree, [1]) is False 31 | 32 | # Mixed output 33 | tree = [[True, False], [False, True]] 34 | assert get_value(tree, [1, 0]) is False 35 | assert get_value(tree, [1, 1]) is True 36 | 37 | 38 | def test_handle_mask(): 39 | # Always true 40 | mask = True 41 | assert handle_mask(mask, [0]) == [True] 42 | mask = [True] 43 | assert handle_mask(mask, [0]) == [True] 44 | # Common pattern where x is transformed, not y 45 | mask = [True, False] 46 | assert handle_mask(mask, [0, 1]) == [True, False] 47 | 48 | 49 | def test_apply_fun(): 50 | def fun(x): 51 | return x * 0.0 52 | 53 | inp = np.ones([10]) 54 | out = np.zeros([10]) 55 | 56 | assert_almost_equal(apply_fun(inp.copy(), fun, True), out) 57 | assert_almost_equal(apply_fun(inp.copy(), fun, False), inp) 58 | 59 | assert is_same(apply_fun([inp.copy()], fun, True), [out]) 60 | 61 | assert is_same(apply_fun([inp.copy()], fun, True), [out]) 62 | assert is_same(apply_fun([inp.copy()], fun, False), [inp]) 63 | 64 | assert is_same(apply_fun([inp.copy(), inp.copy()], fun, True), [out, out]) 65 | assert is_same(apply_fun([inp.copy(), inp.copy()], fun, False), [inp, inp]) 66 | 67 | assert is_same(apply_fun([inp.copy(), inp.copy()], fun, [True, False]), [out, inp]) 68 | assert is_same(apply_fun([inp.copy(), inp.copy()], fun, [False, True]), [inp, out]) 69 | 70 | assert is_same(apply_fun([[inp.copy(), inp.copy()], inp.copy()], fun, [True, False]), [[out, out], inp]) 71 | assert is_same(apply_fun([[inp.copy(), inp.copy()], inp.copy()], fun, [False, True]), [[inp, inp], out]) 72 | 73 | assert is_same(apply_fun([[inp.copy(), inp.copy()], inp.copy()], fun, [[True, True], False]), [[out, out], inp]) 74 | assert is_same(apply_fun([[inp.copy(), inp.copy()], inp.copy()], fun, [[False, True], True]), [[inp, out], out]) 75 | 76 | assert is_same(apply_fun([[inp.copy(), inp.copy()], inp.copy()], fun, [[False, True], False]), [[inp, out], inp]) 77 | 78 | 79 | if __name__ == '__main__': 80 | pytest.main([__file__]) 81 | -------------------------------------------------------------------------------- /tests/transform/sequences/sequence_transformers_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from keras.utils import Sequence 4 | 5 | from transform.sequences import (RandomRotationTransformer, RandomShiftTransformer, RandomZoomTransformer, 6 | RandomChannelShiftTransformer, RandomShearTransformer, RandomHorizontalFlipTransformer, 7 | RandomVerticalFlipTransformer 8 | ) 9 | 10 | 11 | class TestSequence(Sequence): 12 | """Create a X,Y tuple""" 13 | 14 | def __getitem__(self, index): 15 | return np.arange(5 * 20 * 20 * 3).reshape([5, 20, 20, 3]), np.arange(5 * 20 * 20 * 3).reshape([5, 20, 20, 3]) 16 | 17 | def __len__(self): 18 | return 10 19 | 20 | 21 | class TestTreeSequence(Sequence): 22 | """Create a [X1,X2],Y1 tuple.""" 23 | 24 | def __getitem__(self, index): 25 | return [np.arange(5 * 20 * 20 * 3).reshape([5, 20, 20, 3]), 26 | np.arange(5 * 12 * 12 * 3).reshape([5, 12, 12, 3])], np.arange( 27 | 5 * 10 * 10 * 3).reshape([5, 10, 10, 3]) 28 | 29 | def __len__(self): 30 | return 10 31 | 32 | 33 | def test_random_rot(): 34 | np.random.seed(1337) 35 | inner_transformer(RandomRotationTransformer, rg=25) 36 | 37 | 38 | def test_random_shift(): 39 | np.random.seed(1337) 40 | inner_transformer(RandomShiftTransformer, wrg=0.5, hrg=0.5) 41 | 42 | 43 | def test_random_zoom(): 44 | np.random.seed(1337) 45 | inner_transformer(RandomZoomTransformer, zoom_range=(.2, 1.5)) 46 | 47 | 48 | def test_random_intensity_shift(): 49 | np.random.seed(1337) 50 | inner_transformer(RandomChannelShiftTransformer, intensity=10) 51 | 52 | 53 | def test_random_shear(): 54 | np.random.seed(1337) 55 | inner_transformer(RandomShearTransformer, intensity=10) 56 | 57 | 58 | def test_random_flip(): 59 | np.random.seed(1337) 60 | # This SHOULD work since batch_size is 5 and we have 50% chances of doing a flip. 61 | inner_transformer(RandomHorizontalFlipTransformer) 62 | inner_transformer(RandomVerticalFlipTransformer) 63 | 64 | 65 | def test_assert(): 66 | with pytest.raises(AssertionError): 67 | _ = RandomHorizontalFlipTransformer()[0] 68 | 69 | 70 | def inner_transformer(transformer_cls, **kwargs): 71 | transformer = transformer_cls(**kwargs)(TestSequence()) 72 | # Assert that X changes between 2 calls and Y does not. 73 | assert np.any(np.not_equal(transformer[0][0], transformer[1][0])) and np.all( 74 | np.equal(transformer[0][1], transformer[1][1])) 75 | 76 | transformer = transformer_cls(**kwargs)(TestTreeSequence()) 77 | 78 | assert all([np.any(np.not_equal(t0, t1)) for t0, t1 in zip(transformer[0][0], transformer[1][0])]) and all( 79 | [np.all(np.equal(t0, t1)) for t0, t1 in zip(transformer[0][1], transformer[1][1])]) 80 | 81 | # Test Mask 82 | transformer = transformer_cls(**kwargs)(TestTreeSequence(), mask=False) 83 | 84 | assert all([np.any(np.equal(t0, t1)) for t0, t1 in zip(transformer[0][0], transformer[1][0])]) and np.equal( 85 | transformer[0][1], transformer[1][1]).all() 86 | 87 | transformer = transformer_cls(**kwargs)(TestTreeSequence(), mask=[True, True]) 88 | 89 | assert all([np.any(np.not_equal(t0, t1)) for t0, t1 in zip(transformer[0][0], transformer[1][0])]) and np.not_equal( 90 | transformer[0][1], transformer[1][1]).any() 91 | 92 | # Should transform the same way for X and y 93 | transformer = transformer_cls(**kwargs)(TestSequence(), mask=[True, True]) 94 | assert (np.equal(*transformer[0])).all() 95 | 96 | # Common case where we augment X but not y 97 | transformer = transformer_cls(**kwargs)(TestSequence(), mask=[True, False]) 98 | assert (np.not_equal(*transformer[0])).any() 99 | 100 | 101 | if __name__ == '__main__': 102 | pytest.main([__file__]) 103 | -------------------------------------------------------------------------------- /transform/utils/transformations.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER : This code has been heavily borrowed from the Keras source code. 3 | Fairly basic set of tools for real-time data augmentation on image data. 4 | Can easily be extended to include new transformations, 5 | new preprocessing methods, etc... 6 | """ 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | 10 | import numpy as np 11 | from keras.preprocessing.image import apply_affine_transform 12 | from keras_preprocessing.image import flip_axis 13 | 14 | try: 15 | from PIL import Image as pil_image 16 | except ImportError: 17 | pil_image = None 18 | 19 | 20 | def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0, 21 | fill_mode='nearest', cval=0., theta=None): 22 | """Performs a random rotation of a Numpy image tensor. 23 | 24 | # Arguments 25 | x: Input tensor. Must be 3D. 26 | rg: Rotation range, in degrees. 27 | row_axis: Index of axis for rows in the input tensor. 28 | col_axis: Index of axis for columns in the input tensor. 29 | channel_axis: Index of axis for channels in the input tensor. 30 | fill_mode: Points outside the boundaries of the input 31 | are filled according to the given mode 32 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 33 | cval: Value used for points outside the boundaries 34 | of the input if `mode='constant'`. 35 | theta: Value to disable randomness or None. 36 | 37 | # Returns 38 | Rotated Numpy image tensor. 39 | """ 40 | theta = np.pi / 180 * np.random.uniform(-rg, rg) if theta is None else theta 41 | x = apply_affine_transform(x, theta=theta, channel_axis=channel_axis, fill_mode=fill_mode, cval=cval, row_axis=row_axis, col_axis=col_axis) 42 | return x 43 | 44 | 45 | def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0, 46 | fill_mode='nearest', cval=0., tx=None, ty=None): 47 | """Performs a random spatial shift of a Numpy image tensor. 48 | 49 | # Arguments 50 | x: Input tensor. Must be 3D. 51 | wrg: Width shift range, as a float fraction of the width. 52 | hrg: Height shift range, as a float fraction of the height. 53 | row_axis: Index of axis for rows in the input tensor. 54 | col_axis: Index of axis for columns in the input tensor. 55 | channel_axis: Index of axis for channels in the input tensor. 56 | fill_mode: Points outside the boundaries of the input 57 | are filled according to the given mode 58 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 59 | cval: Value used for points outside the boundaries 60 | of the input if `mode='constant'`. 61 | tx : Value to disable randomness in X or None. 62 | ty : Value to disable randomness in Y or None. 63 | 64 | # Returns 65 | Shifted Numpy image tensor. 66 | """ 67 | h, w = x.shape[row_axis], x.shape[col_axis] 68 | tx = np.random.uniform(-hrg, hrg) * h if tx is None else tx 69 | ty = np.random.uniform(-wrg, wrg) * w if ty is None else ty 70 | 71 | tx *= h 72 | ty *= w 73 | x = apply_affine_transform(x, tx=tx, ty=ty, channel_axis=channel_axis, fill_mode=fill_mode, cval=cval) 74 | return x 75 | 76 | 77 | def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0, 78 | fill_mode='nearest', cval=0., known_intensity=None): 79 | """Performs a random spatial shear of a Numpy image tensor. 80 | 81 | # Arguments 82 | x: Input tensor. Must be 3D. 83 | intensity: Transformation intensity. 84 | row_axis: Index of axis for rows in the input tensor. 85 | col_axis: Index of axis for columns in the input tensor. 86 | channel_axis: Index of axis for channels in the input tensor. 87 | fill_mode: Points outside the boundaries of the input 88 | are filled according to the given mode 89 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 90 | cval: Value used for points outside the boundaries 91 | of the input if `mode='constant'`. 92 | known_intensity: Value to disable randomness or None. 93 | 94 | # Returns 95 | Sheared Numpy image tensor. 96 | """ 97 | shear = np.random.uniform(-intensity, intensity) if known_intensity is None else known_intensity 98 | x = apply_affine_transform(x, shear=shear, channel_axis=channel_axis, fill_mode=fill_mode, cval=cval, row_axis=row_axis, col_axis=col_axis) 99 | return x 100 | 101 | 102 | def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0, 103 | fill_mode='nearest', cval=0., z_known=None): 104 | """Performs a random spatial zoom of a Numpy image tensor. 105 | 106 | # Arguments 107 | x: Input tensor. Must be 3D. 108 | zoom_range: Tuple of floats; zoom range for width and height. 109 | row_axis: Index of axis for rows in the input tensor. 110 | col_axis: Index of axis for columns in the input tensor. 111 | channel_axis: Index of axis for channels in the input tensor. 112 | fill_mode: Points outside the boundaries of the input 113 | are filled according to the given mode 114 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 115 | cval: Value used for points outside the boundaries 116 | of the input if `mode='constant'`. 117 | z_known: Value to disable randomness or None. 118 | 119 | # Returns 120 | Zoomed Numpy image tensor. 121 | 122 | # Raises 123 | ValueError: if `zoom_range` isn't a tuple. 124 | """ 125 | if z_known is None: 126 | if len(zoom_range) != 2: 127 | raise ValueError('`zoom_range` should be a tuple or list of two floats. ' 128 | 'Received arg: ', zoom_range) 129 | if zoom_range[0] == 1 and zoom_range[1] == 1: 130 | zx, zy = 1, 1 131 | else: 132 | zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) 133 | else: 134 | zx, zy = z_known 135 | x = apply_affine_transform(x, zx=zx, zy=zy, channel_axis=channel_axis, fill_mode=fill_mode, cval=cval, row_axis=row_axis, col_axis=col_axis) 136 | return x 137 | 138 | 139 | def random_channel_shift(x, intensity, channel_axis=0, known_intensity=None): 140 | x = np.rollaxis(x, channel_axis, 0) 141 | known_intensity = np.random.uniform(-intensity, intensity) if known_intensity is None else known_intensity 142 | min_x, max_x = np.min(x), np.max(x) 143 | channel_images = [np.clip(x_channel + known_intensity, min_x, max_x) 144 | for x_channel in x] 145 | x = np.stack(channel_images, axis=0) 146 | x = np.rollaxis(x, 0, channel_axis + 1) 147 | return x 148 | 149 | 150 | def transform_matrix_offset_center(matrix, x, y): 151 | o_x = float(x) / 2 + 0.5 152 | o_y = float(y) / 2 + 0.5 153 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 154 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 155 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) 156 | return transform_matrix 157 | 158 | 159 | def flip_horizontal(x, value, col_axis=2): 160 | if value < 0.5: 161 | return flip_axis(x, col_axis) 162 | else: 163 | return x 164 | 165 | 166 | def flip_vertical(x, value, row_axis=1): 167 | if value < 0.5: 168 | return flip_axis(x, row_axis) 169 | else: 170 | return x 171 | -------------------------------------------------------------------------------- /transform/sequences/sequence_transformers.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import numpy as np 3 | from keras.utils import Sequence 4 | 5 | from transform.utils import apply_fun, get_batch_size 6 | from transform.utils.transformations import (random_rotation, random_shift, random_zoom, random_channel_shift, 7 | random_shear, flip_horizontal, flip_vertical) 8 | 9 | 10 | class BaseSequenceTransformer(Sequence): 11 | """Base object for transformers. 12 | 13 | # Arguments 14 | data_format: `'channels_last'`, `'channels_first'` or None 15 | """ 16 | 17 | def __init__(self, data_format=None): 18 | self.sequence = None 19 | self.mask = None 20 | self.batch_size = None # We do not know yet 21 | self.transformation = id 22 | if data_format is None: 23 | data_format = K.image_data_format() 24 | if data_format not in {'channels_last', 'channels_first'}: 25 | raise ValueError('`data_format` should be `"channels_last"` (channel after row and ' 26 | 'column) or `"channels_first"` (channel before row and column). ' 27 | 'Received arg: ', data_format) 28 | self.data_format = data_format 29 | if data_format == 'channels_first': 30 | self.channel_axis = 1 31 | self.row_axis = 2 32 | self.col_axis = 3 33 | if data_format == 'channels_last': 34 | self.channel_axis = 3 35 | self.row_axis = 1 36 | self.col_axis = 2 37 | 38 | self.common_args = {'row_axis': self.row_axis, 'col_axis': self.col_axis, 39 | 'channel_axis': self.channel_axis - 1} 40 | 41 | def __call__(self, seq, mask=(True, False)): 42 | self.mask = mask 43 | self.sequence = seq 44 | return self 45 | 46 | def on_epoch_end(self): 47 | pass 48 | 49 | def apply_transformation(self, x_, transformation, args): 50 | """ 51 | Apply the `transformation` to the input `x_`. 52 | :param x_: np.array, the input 53 | :param transformation: function to apply 54 | :param args: dict, arguments for `transformation` 55 | :return: np.array 56 | """ 57 | return np.asarray( 58 | list(map(lambda args: transformation(args[0], **args[1]), 59 | zip(x_, args)))) 60 | 61 | def get_args(self): 62 | """Retrieve args to provide to the transformer. The args should not be aware of the input dimension. 63 | 64 | # Returns 65 | A list of batch_size args. 66 | """ 67 | raise NotImplementedError 68 | 69 | def __getitem__(self, index): 70 | assert self.sequence, "This transformer {} has not been called with a Sequence object".format( 71 | self.__class__.__name__) 72 | batch = self.sequence[index] 73 | if self.batch_size is None: 74 | # The first batch should be the maximum batch_size i.e. not the last. 75 | self.batch_size = get_batch_size(batch) 76 | 77 | args = self.get_args() 78 | for arg in args: 79 | arg.update(self.common_args) 80 | 81 | return apply_fun(batch, self.apply_transformation, self.mask, transformation=self.transformation, args=args) 82 | 83 | def __len__(self): 84 | return len(self.sequence) 85 | 86 | 87 | class RandomRotationTransformer(BaseSequenceTransformer): 88 | """Transformer to do random rotation. 89 | 90 | # Arguments 91 | rg: Range of rotation 92 | fill_mode: 93 | """ 94 | 95 | def __init__(self, rg, fill_mode='nearest'): 96 | super().__init__() 97 | self.rg = rg 98 | self.transformation = random_rotation 99 | self.fill_mode = fill_mode 100 | 101 | def get_args(self): 102 | return [{'rg': self.rg, 103 | 'theta': np.pi / 180 * np.random.uniform(-self.rg, self.rg), 104 | 'fill_mode': self.fill_mode} for _ in range(self.batch_size)] 105 | 106 | 107 | class RandomShiftTransformer(BaseSequenceTransformer): 108 | """Transformer to do random shift. 109 | 110 | # Arguments 111 | wrg: Width shift range, as a float fraction of the width. 112 | hrg: Height shift range, as a float fraction of the height. 113 | """ 114 | 115 | def __init__(self, wrg, hrg): 116 | super().__init__() 117 | self.wrg = wrg 118 | self.hrg = hrg 119 | self.transformation = random_shift 120 | 121 | def get_args(self): 122 | return [{'tx': np.random.uniform(-self.hrg, self.hrg), 'ty': np.random.uniform(-self.wrg, self.wrg), 123 | 'wrg': self.wrg, 'hrg': self.hrg} for _ in range(self.batch_size)] 124 | 125 | 126 | class RandomZoomTransformer(BaseSequenceTransformer): 127 | """Transformer to do random zoom. 128 | 129 | # Arguments 130 | zoom_range: Tuple of floats; zoom range for width and height. 131 | """ 132 | 133 | def __init__(self, zoom_range): 134 | super().__init__() 135 | self.zoom_range = zoom_range 136 | self.transformation = random_zoom 137 | 138 | def get_args(self): 139 | if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: 140 | dt = [(1, 1) for _ in range(self.batch_size)] 141 | else: 142 | dt = [np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2) for _ in range(self.batch_size)] 143 | return [{'z_known': d, 'zoom_range': self.zoom_range} for d in dt] 144 | 145 | 146 | class RandomChannelShiftTransformer(BaseSequenceTransformer): 147 | """Transformer to do random zoom. 148 | 149 | # Arguments 150 | intensity: float, intensity range 151 | """ 152 | 153 | def __init__(self, intensity): 154 | super().__init__() 155 | self.intensity = intensity 156 | self.transformation = random_channel_shift 157 | self.common_args = {'channel_axis': self.channel_axis - 1} 158 | 159 | def get_args(self): 160 | return [{'known_intensity': np.random.uniform(-self.intensity, self.intensity), 'intensity': self.intensity} for 161 | _ in range(self.batch_size)] 162 | 163 | 164 | class RandomShearTransformer(BaseSequenceTransformer): 165 | """Transformer to do random shear. 166 | 167 | # Arguments 168 | intensity: float, maximum shear. 169 | """ 170 | 171 | def __init__(self, intensity): 172 | super().__init__() 173 | self.intensity = intensity 174 | self.transformation = random_shear 175 | 176 | def get_args(self): 177 | return [{'known_intensity': np.random.uniform(-self.intensity, self.intensity), 'intensity': self.intensity} for 178 | _ in range(self.batch_size)] 179 | 180 | 181 | class RandomHorizontalFlipTransformer(BaseSequenceTransformer): 182 | """Transformer to do random horizontal flip.""" 183 | def __init__(self): 184 | super().__init__() 185 | self.transformation = flip_horizontal 186 | # The -1 is important here! 187 | self.common_args = {'col_axis': self.col_axis - 1} 188 | 189 | def get_args(self): 190 | return [{'value': np.random.random()} for 191 | _ in range(self.batch_size)] 192 | 193 | 194 | class RandomVerticalFlipTransformer(BaseSequenceTransformer): 195 | """Transformer to do random vertical flip.""" 196 | def __init__(self): 197 | super().__init__() 198 | self.transformation = flip_vertical 199 | # The -1 is important here! 200 | self.common_args = {'row_axis': self.row_axis - 1} 201 | 202 | def get_args(self): 203 | return [{'value': np.random.random()} for 204 | _ in range(self.batch_size)] 205 | --------------------------------------------------------------------------------