├── .bumpversion.cfg ├── .github └── workflows │ ├── develop.yml │ ├── master-build-dist.yml │ └── release.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── assets └── images │ ├── policy_v0 │ ├── v0_0.png │ ├── v0_1.png │ ├── v0_2.png │ ├── v0_3.png │ └── v0_4.png │ ├── policy_v1 │ ├── v1_0.png │ ├── v1_1.png │ ├── v1_10.png │ ├── v1_11.png │ ├── v1_12.png │ ├── v1_13.png │ ├── v1_14.png │ ├── v1_15.png │ ├── v1_16.png │ ├── v1_17.png │ ├── v1_18.png │ ├── v1_19.png │ ├── v1_2.png │ ├── v1_3.png │ ├── v1_4.png │ ├── v1_5.png │ ├── v1_6.png │ ├── v1_7.png │ ├── v1_8.png │ └── v1_9.png │ ├── policy_v2 │ ├── v2_0.png │ ├── v2_1.png │ ├── v2_10.png │ ├── v2_11.png │ ├── v2_12.png │ ├── v2_13.png │ ├── v2_14.png │ ├── v2_2.png │ ├── v2_3.png │ ├── v2_4.png │ ├── v2_5.png │ ├── v2_6.png │ ├── v2_7.png │ ├── v2_8.png │ └── v2_9.png │ └── policy_v3 │ ├── v3_0.png │ ├── v3_1.png │ ├── v3_10.png │ ├── v3_11.png │ ├── v3_12.png │ ├── v3_13.png │ ├── v3_14.png │ ├── v3_2.png │ ├── v3_3.png │ ├── v3_4.png │ ├── v3_5.png │ ├── v3_6.png │ ├── v3_7.png │ ├── v3_8.png │ └── v3_9.png ├── bbaug ├── __init__.py ├── _version.py ├── augmentations │ ├── __init__.py │ └── augmentations.py ├── exceptions.py ├── policies │ ├── __init__.py │ └── policies.py └── visuals │ ├── __init__.py │ └── visuals.py ├── conftest.py ├── coverage.svg ├── notebooks ├── custom_augmentations.ipynb ├── data │ └── example_dataset │ │ ├── boxes │ │ └── dog_1.txt │ │ └── images │ │ └── dog_1.jpg └── example_run.ipynb ├── pyproject.toml ├── setup.py └── tests ├── test_augmentations.py ├── test_policies.py └── test_visuals.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.4.2 3 | commit = True 4 | tag = True 5 | parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+).(?P\d+))? 
6 | serialize = 7 | {major}.{minor}.{patch}-{release}.{build} 8 | {major}.{minor}.{patch} 9 | 10 | [bumpversion:part:release] 11 | optional_value = prod 12 | first_value = beta 13 | values = 14 | beta 15 | rc 16 | prod 17 | 18 | [bumpversion:part:build] 19 | 20 | [bumpversion:file:.bumpversion.cfg] 21 | 22 | [bumpversion:file:setup.py] 23 | 24 | [bumpversion:file:pyproject.toml] 25 | 26 | [bumpversion:file:bbaug/_version.py] 27 | -------------------------------------------------------------------------------- /.github/workflows/develop.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Develop Branch CI 5 | 6 | on: 7 | push: 8 | branches: [ develop ] 9 | pull_request: 10 | branches: [ develop ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.6, 3.7, 3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install imgaug 30 | pip install numpy 31 | - name: Lint with flake8 32 | run: | 33 | pip install flake8 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 ./bbaug --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 37 | flake8 /bbaug. 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: Test with pytest 39 | run: | 40 | pip install pytest 41 | pip install pytest-mock 42 | pytest 43 | -------------------------------------------------------------------------------- /.github/workflows/master-build-dist.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Master Branch CI 5 | 6 | on: 7 | push: 8 | branches: 9 | - master 10 | tags-ignore: 11 | - '*beta*' 12 | - '*rc*' 13 | pull_request: 14 | branches: 15 | - master 16 | jobs: 17 | build-dist: 18 | runs-on: ubuntu-latest 19 | strategy: 20 | matrix: 21 | python-version: [3.6] 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install imgaug 32 | pip install numpy 33 | - name: Lint with flake8 34 | run: | 35 | pip install flake8 36 | # stop the build if there are Python syntax errors or undefined names 37 | flake8 ./bbaug --count --select=E9,F63,F7,F82 --show-source --statistics 38 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 39 | flake8 /bbaug. 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 40 | - name: Test with pytest 41 | run: | 42 | pip install pytest 43 | pip install pytest-mock 44 | pytest -vv -s 45 | - name: Build dist 46 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') 47 | run: | 48 | pip install --user --upgrade setuptools wheel 49 | python setup.py sdist bdist_wheel 50 | - name: Publish package to test PyPi 51 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') 52 | uses: pypa/gh-action-pypi-publish@master 53 | with: 54 | user: __token__ 55 | password: ${{ secrets.PYPI_TOKEN }} 56 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Release Branch CI 5 | 6 | on: 7 | push: 8 | branches: 9 | - release 10 | tags: 11 | - '*rc*' 12 | pull_request: 13 | branches: 14 | - release 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: [3.6] 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install imgaug 32 | pip install numpy 33 | - name: Lint with flake8 34 | run: | 35 | pip install flake8 36 | # stop the build if there are Python syntax errors or undefined names 37 | flake8 ./bbaug --count --select=E9,F63,F7,F82 --show-source --statistics 38 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 39 | flake8 /bbaug. 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 40 | - name: Test with pytest 41 | run: | 42 | pip install pytest 43 | pip install pytest-mock 44 | pytest -vv -s 45 | - name: Build dist 46 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') 47 | run: | 48 | pip install --user --upgrade setuptools wheel 49 | python setup.py sdist bdist_wheel 50 | - name: Publish package to test PyPi 51 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') 52 | uses: pypa/gh-action-pypi-publish@master 53 | with: 54 | user: __token__ 55 | password: ${{ secrets.PYPI_TEST_TOKEN }} 56 | repository_url: https://test.pypi.org/legacy/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE 132 | .idea 133 | 134 | # poetry 135 | poetry.lock -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.4.2 2 | - Fixed bug where bounding box specific augmentations would not be clipped or removed 3 | 4 | # 0.4.1 5 | - Now possible to pass a random state to the policy container ensuring reproducible augmentations 6 | 7 | # 0.4.0 8 | - Apply augmentations to bounding boxes individually 9 | - Fixed bug in `visualise_policy` 10 | 11 | # 0.3.0 12 | - Class labels are now required for bounding boxes 13 | 14 | # 0.2.1 15 | - Fixed bug where the cutout would be larger than the image 16 | 17 | # 0.2.0 18 | - Implementation of policy version 0, 1 and 2 19 | - Module to aid in the visualisation of a policy 20 | - Notebooks for bbaug integration into training model and custom policies 21 | 22 | # 0.1.0 23 | - Implementation of policy version 3 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Harpal Sahota 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software 
without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help 2 | 3 | help: ## Shows this help 4 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 5 | 6 | test: ## Run the tests 7 | @pytest ./tests -vv -s; 8 | 9 | test-report: ## Run the tests and return a coverage report 10 | @pytest --cov-report term-missing --cov=bbaug tests/ 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Master Branch Dist CI](https://github.com/harpalsahota/bbaug/workflows/Master%20Branch%20Dist%20CI/badge.svg?branch=master) ![Alt text](./coverage.svg) [![PyPI version](https://badge.fury.io/py/bbaug.svg)](https://badge.fury.io/py/bbaug) [![Downloads](https://pepy.tech/badge/bbaug)](https://pepy.tech/project/bbaug) 2 | 3 | 4 | # BBAug 5 | 6 
| BBAug is a Python package for the implementation of Google’s Brain Team’s bounding box augmentation policies. 7 | The package is aimed at PyTorch users who wish to use these policies in the augmentation of bounding boxes during the 8 | training of a model. Currently all 4 versions of the policies are implemented. This package builds on top of the excellent 9 | image augmentations package [imgaug](https://github.com/aleju/imgaug). 10 | 11 | **References** 12 | - [Paper](https://arxiv.org/abs/1906.11172) 13 | - [Tensorflow Policy Code](https://github.com/tensorflow/tpu/blob/2264f53d95852efbfb82ea27f03ca749e1205968/models/official/detection/utils/autoaugment_utils.py) 14 | 15 | ## Features 16 | 17 | - [x] Implementation of all 4 policy versions 18 | - [x] Custom policies 19 | - [x] Custom augmentations 20 | - [x] Bounding boxes are removed if they fall outside of the image* 21 | - [x] Bounding boxes are clipped if they are partially outside the image* 22 | - [x] The direction of directional augmentations (e.g. rotation) is randomly determined 23 | 24 | *Does not happen for bounding box specific augmentations 25 | 26 | ## To Do 27 | - [x] ~~Implementation of version 2 of policies~~ (implemented in v0.2) 28 | - [x] ~~Implementation of version 1 of policies~~ (implemented in v0.2) 29 | - [x] ~~For bounding box augmentations apply the probability individually for each box not collectively~~ (implemented in v0.4) 30 | 31 | ## Installation 32 | 33 | Installation is best done via pip: 34 | > pip install bbaug 35 | 36 | ### Prerequisites 37 | - Python 3.6+ 38 | - PyTorch 39 | - Torchvision 40 | 41 | ## Description and Usage 42 | 43 | For a detailed description of usage please refer to the Python notebooks provided in the `notebooks` folder.
44 | 45 | An augmentation is defined by 3 attributes: 46 | - **Name**: Name of the augmentation 47 | - **Probability**: Probability of the augmentation being applied 48 | - **Magnitude**: The degree of the augmentation (values are integers between 0 and 10) 49 | 50 | A `sub-policy` is a collection of augmentations, e.g. 51 | ```python 52 | sub_policy = [('translation', 0.5, 1), ('rotation', 1.0, 9)] 53 | ``` 54 | In the above example we have two augmentations in a sub-policy. The `translation` augmentation has a 55 | probability of 0.5 and a magnitude of 1, whereas the `rotation` augmentation has a probability of 1.0 and a 56 | magnitude of 9. The magnitudes do not translate directly into augmentation values, i.e. a magnitude of 9 57 | does not mean a 9 degree rotation. Instead, scaling is applied to the magnitude to determine the value passed 58 | to the augmentation method. The scaling varies depending on the augmentation used. 59 | 60 | A `policy` is a set of sub-policies: 61 | ```python 62 | policies = [ 63 | [('translation', 0.5, 1), ('rotation', 1.0, 9)], 64 | [('colour', 0.5, 1), ('cutout', 1.0, 9)], 65 | [('rotation', 0.5, 1), ('solarize', 1.0, 9)] 66 | ] 67 | ``` 68 | During training, a random sub-policy is selected from the list of sub-policies and applied to the image; because 69 | each augmentation has its own probability, this adds a degree of stochasticity to training. 70 | 71 | ### Augmentations 72 | 73 | Each augmentation contains a string referring to the name of the augmentation. The `augmentations` module 74 | contains a dictionary mapping the name to a method reference of the augmentation. 75 | ```python 76 | from bbaug.augmentations import NAME_TO_AUGMENTATION 77 | print(NAME_TO_AUGMENTATION) # Shows the dictionary of the augmentation name to the method reference 78 | ``` 79 | Some augmentations are applied only to the bounding boxes. Augmentations which have the suffix `BBox` are only 80 | applied to the bounding boxes in the image.
81 | 82 | #### Listing All Policies Available 83 | To obtain a list of all available policies, run the `list_policies` function. This will return a list of strings 84 | containing the function names for the policy sets. 85 | ```python 86 | from bbaug.policies import list_policies 87 | print(list_policies()) # List of policies available 88 | ``` 89 | 90 | #### Listing the policies in a policy set 91 | ```python 92 | from bbaug.policies import policies_v3 93 | print(policies_v3()) # Will list all the policies in version 3 94 | ``` 95 | 96 | #### Visualising a Policy 97 | 98 | To visualise a policy on a single image, a `visualise_policy` function is available in the `visuals` module. 99 | 100 | ```python 101 | from bbaug.visuals import visualise_policy 102 | visualise_policy( 103 | 'path/to/image', 104 | 'save/dir/of/augmentations', 105 | bounding_boxes, # Bounding boxes is a list of list of bounding boxes in pixels (int): e.g. [[x_min, y_min, x_max, y_max], [x_min, y_min, x_max, y_max]] 106 | labels, # Class labels for the bounding boxes as an iterable of ints e.g. [0, 5] 107 | policy, # the policy to visualise 108 | name_to_augmentation, # (optional, default: augmentations.NAME_TO_AUGMENTATION) The dictionary mapping the augmentation name to the augmentation method 109 | ) 110 | ``` 111 | 112 | #### Policy Container 113 | To help integrate the policies into training, a `PolicyContainer` class is available in the `policies` 114 | module.
The container accepts the following inputs: 115 | - **policy_set** (required): The policy set to use 116 | - **name_to_augmentation** (optional, default: `augmentations.NAME_TO_AUGMENTATION`): The dictionary mapping the augmentation name to the augmentation method 117 | - **return_yolo** (optional, default: `False`): Return the bounding boxes in YOLO format; otherwise `[x_min, y_min, x_max, y_max]` in pixels is returned 118 | 119 | Usage of the policy container: 120 | ```python 121 | from bbaug import policies 122 | 123 | # select policy v3 set 124 | aug_policy = policies.policies_v3() 125 | 126 | # instantiate the policy container with the selected policy set 127 | policy_container = policies.PolicyContainer(aug_policy) 128 | 129 | # select a random policy from the policy set 130 | random_policy = policy_container.select_random_policy() 131 | 132 | # Apply the augmentation. Returns the augmented image and bounding boxes. 133 | # Image is a numpy array of the image 134 | # Bounding boxes is a list of list of bounding boxes in pixels (int). 135 | # e.g. [[x_min, y_min, x_max, y_max], [x_min, y_min, x_max, y_max]] 136 | # Labels are the class labels for the bounding boxes as an iterable of ints e.g. [1,0] 137 | img_aug, bbs_aug = policy_container.apply_augmentation(random_policy, image, bounding_boxes, labels) 138 | # img_aug: numpy array of the augmented image 139 | # bbs_aug: numpy array of augmented bounding boxes in format: [[label, x_min, y_min, x_max, y_max],...] 140 | ``` 141 | ## Policy Implementation 142 | The policies implemented in `bbaug` are shown below. Each column represents a different run for that given sub-policy; 143 | as each augmentation in the sub-policy has its own probability, this results in variations between runs. 144 | 145 | #### Version 0 146 | These are the policies used in the paper.
147 | 148 | ![image](assets/images/policy_v0/v0_0.png) 149 | ![image](assets/images/policy_v0/v0_1.png) 150 | ![image](assets/images/policy_v0/v0_2.png) 151 | ![image](assets/images/policy_v0/v0_3.png) 152 | ![image](assets/images/policy_v0/v0_4.png) 153 | #### Version 1 154 | ![image](assets/images/policy_v1/v1_0.png) 155 | ![image](assets/images/policy_v1/v1_1.png) 156 | ![image](assets/images/policy_v1/v1_2.png) 157 | ![image](assets/images/policy_v1/v1_3.png) 158 | ![image](assets/images/policy_v1/v1_4.png) 159 | ![image](assets/images/policy_v1/v1_5.png) 160 | ![image](assets/images/policy_v1/v1_6.png) 161 | ![image](assets/images/policy_v1/v1_7.png) 162 | ![image](assets/images/policy_v1/v1_8.png) 163 | ![image](assets/images/policy_v1/v1_9.png) 164 | ![image](assets/images/policy_v1/v1_10.png) 165 | ![image](assets/images/policy_v1/v1_11.png) 166 | ![image](assets/images/policy_v1/v1_12.png) 167 | ![image](assets/images/policy_v1/v1_13.png) 168 | ![image](assets/images/policy_v1/v1_14.png) 169 | ![image](assets/images/policy_v1/v1_15.png) 170 | ![image](assets/images/policy_v1/v1_16.png) 171 | ![image](assets/images/policy_v1/v1_17.png) 172 | ![image](assets/images/policy_v1/v1_18.png) 173 | ![image](assets/images/policy_v1/v1_19.png) 174 | #### Version 2 175 | ![image](assets/images/policy_v2/v2_0.png) 176 | ![image](assets/images/policy_v2/v2_1.png) 177 | ![image](assets/images/policy_v2/v2_2.png) 178 | ![image](assets/images/policy_v2/v2_3.png) 179 | ![image](assets/images/policy_v2/v2_4.png) 180 | ![image](assets/images/policy_v2/v2_5.png) 181 | ![image](assets/images/policy_v2/v2_6.png) 182 | ![image](assets/images/policy_v2/v2_7.png) 183 | ![image](assets/images/policy_v2/v2_8.png) 184 | ![image](assets/images/policy_v2/v2_9.png) 185 | ![image](assets/images/policy_v2/v2_10.png) 186 | ![image](assets/images/policy_v2/v2_11.png) 187 | ![image](assets/images/policy_v2/v2_12.png) 188 | ![image](assets/images/policy_v2/v2_13.png) 189 | 
![image](assets/images/policy_v2/v2_14.png) 190 | #### Version 3 191 | ![image](assets/images/policy_v3/v3_0.png) 192 | ![image](assets/images/policy_v3/v3_1.png) 193 | ![image](assets/images/policy_v3/v3_2.png) 194 | ![image](assets/images/policy_v3/v3_3.png) 195 | ![image](assets/images/policy_v3/v3_4.png) 196 | ![image](assets/images/policy_v3/v3_5.png) 197 | ![image](assets/images/policy_v3/v3_6.png) 198 | ![image](assets/images/policy_v3/v3_7.png) 199 | ![image](assets/images/policy_v3/v3_8.png) 200 | ![image](assets/images/policy_v3/v3_9.png) 201 | ![image](assets/images/policy_v3/v3_10.png) 202 | ![image](assets/images/policy_v3/v3_11.png) 203 | ![image](assets/images/policy_v3/v3_12.png) 204 | ![image](assets/images/policy_v3/v3_13.png) 205 | ![image](assets/images/policy_v3/v3_14.png) 206 | -------------------------------------------------------------------------------- /assets/images/policy_v0/v0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v0/v0_0.png -------------------------------------------------------------------------------- /assets/images/policy_v0/v0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v0/v0_1.png -------------------------------------------------------------------------------- /assets/images/policy_v0/v0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v0/v0_2.png -------------------------------------------------------------------------------- /assets/images/policy_v0/v0_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v0/v0_3.png -------------------------------------------------------------------------------- /assets/images/policy_v0/v0_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v0/v0_4.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_0.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_1.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_10.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_11.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_12.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_12.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_13.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_14.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_15.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_16.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_17.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_18.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_18.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_19.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_2.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_3.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_4.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_5.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_6.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_7.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_8.png -------------------------------------------------------------------------------- /assets/images/policy_v1/v1_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v1/v1_9.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_0.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_1.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_10.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_11.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_12.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_13.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_14.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_2.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_3.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_4.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_5.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_6.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_7.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_8.png -------------------------------------------------------------------------------- /assets/images/policy_v2/v2_9.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v2/v2_9.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_0.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_1.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_10.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_11.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_12.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_13.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_13.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_14.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_2.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_3.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_4.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_5.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_6.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_7.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_8.png -------------------------------------------------------------------------------- /assets/images/policy_v3/v3_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/assets/images/policy_v3/v3_9.png -------------------------------------------------------------------------------- /bbaug/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import VERSION 2 | __version__ = VERSION 3 | -------------------------------------------------------------------------------- /bbaug/_version.py: -------------------------------------------------------------------------------- 1 | VERSION = '0.4.2' 2 | -------------------------------------------------------------------------------- /bbaug/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations import * # noqa 2 | -------------------------------------------------------------------------------- /bbaug/augmentations/augmentations.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Module for object detection augmentation using google policies 3 | Paper: https://arxiv.org/abs/1906.11172 4 | """ 5 | from functools import wraps 6 | 7 | from imgaug import augmenters as iaa 8 | import numpy as np 9 | 10 | from bbaug.exceptions import InvalidMagnitude 11 | 12 | _MAX_MAGNITUDE = 10.0 13 | BBOX_TRANSLATION = 120 14 | CUTOUT_BBOX = 50 15 | CUTOUT_MAX_PAD_FRACTION = 0.75 16 | CUTOUT_CONST = 100 17 | TRANSLATION_CONST = 250 18 | 19 | __all__ = [ 20 | 'negate', 21 | 'NAME_TO_AUGMENTATION', 22 | 'auto_contrast', 23 | 'brightness', 24 | 'colour', 25 | 'contrast', 26 | 'cutout', 27 | 'cutout_bbox', 28 | 'cutout_fraction', 29 | 'equalise', 30 | 'fliplr_boxes', 31 | 'posterize', 32 | 'rotate', 33 | 'sharpness', 34 | 'shear_x', 35 | 'shear_y', 36 | 'solarize', 37 | 'solarize_add', 38 | 'translate_x', 39 | 'translate_x_bbox', 40 | 'translate_y', 41 | 'translate_y_bbox', 42 | ] 43 | 44 | 45 | def negate(func): 46 | """ 47 | Wrapper function to randomly reverse the direction of an augmentation 48 | 49 | :param func: Augmentation function 50 | :return: func 51 | """ 52 | @wraps(func) 53 | def wrapper(*args, **kwargs): 54 | if np.random.random() < 0.5: 55 | return -func(*args, **kwargs) 56 | return func(*args, **kwargs) 57 | return wrapper 58 | 59 | 60 | def validate_magnitude(func): 61 | """ 62 | Wrapper func to ensure magnitude is within the expected range 63 | 64 | :param func: func to test magnitude of 65 | :return: func 66 | """ 67 | @wraps(func) 68 | def wrapper(*args, **kwargs): 69 | magnitude = args[0] 70 | if (magnitude < 0) or (magnitude > 10): 71 | raise InvalidMagnitude( 72 | f'Magnitude should be > 0 and < 10. 
Actual value: {magnitude}' 73 | ) 74 | return func(*args, **kwargs) 75 | return wrapper 76 | 77 | 78 | @validate_magnitude 79 | def _img_enhance_to_arg(magnitude: int) -> float: 80 | """ 81 | Determine the magnitude of the image enhancement 82 | 83 | :type magnitude: int 84 | :param magnitude: Magnitude of enhancement 85 | :rtype: float 86 | :return: Magnitude of enhancement to apply 87 | """ 88 | return (magnitude / _MAX_MAGNITUDE) * 1.8 + 0.1 89 | 90 | 91 | @negate 92 | @validate_magnitude 93 | def _rotate_mag_to_arg(magnitude: int) -> float: 94 | """ 95 | Determine rotation magnitude 96 | 97 | :type magnitude: int 98 | :param magnitude: Magnitude of rotation 99 | :rtype: float 100 | :return: Rotation in degrees 101 | """ 102 | return (magnitude / _MAX_MAGNITUDE) * 30 103 | 104 | 105 | @negate 106 | @validate_magnitude 107 | def _shear_mag_to_arg(magnitude: int) -> float: 108 | """ 109 | Determine shear magnitude 110 | 111 | :type magnitude: int 112 | :param magnitude: magnitude of shear 113 | :rtype: float 114 | :return: shear magnitude 115 | """ 116 | return (magnitude / _MAX_MAGNITUDE) * 0.3 117 | 118 | 119 | @negate 120 | @validate_magnitude 121 | def _translate_mag_to_arg(magnitude: int, bbox=False) -> int: 122 | """ 123 | Determine translation magnitude in pixels 124 | 125 | :type magnitude: int 126 | :param magnitude: Magnitude of translation 127 | :rtype: int 128 | :return: Translation in pixels 129 | """ 130 | if bbox: 131 | return int((magnitude / _MAX_MAGNITUDE) * BBOX_TRANSLATION) 132 | return int((magnitude / _MAX_MAGNITUDE) * TRANSLATION_CONST) 133 | 134 | 135 | def auto_contrast(_: int) -> iaa.pillike.Autocontrast: 136 | """ 137 | Apply auto contrast to image 138 | 139 | Tensorflow Policy Equivalent: autocontrast 140 | 141 | :type _: int 142 | :param _: unused magnitude 143 | :rtype: iaa.pillike.Autocontrast 144 | :return: Method to auto contrast image 145 | """ 146 | return iaa.pillike.Autocontrast(0) 147 | 148 | 149 | def brightness(magnitude: 
int) -> iaa.pillike.EnhanceBrightness: 150 | """ 151 | Adjust the brightness of an image 152 | 153 | Tensorflow Policy Equivalent: brightness 154 | 155 | :type magnitude: int 156 | :param magnitude: Magnitude of brightness change 157 | :rtype: iaa.pillike.EnhanceBrightness 158 | :return: Method to adjust brightness in image 159 | """ 160 | level = _img_enhance_to_arg(magnitude) 161 | return iaa.pillike.EnhanceBrightness(level) 162 | 163 | 164 | def colour(magnitude: int) -> iaa.pillike.EnhanceColor: 165 | """ 166 | Adjust the brightness of an image 167 | 168 | Tensorflow Policy Equivalent: color 169 | 170 | :type magnitude: int 171 | :param magnitude: Magnitude of colour change 172 | :rtype: iaa.pillike.EnhanceColor 173 | :return: Method to adjust colour in image 174 | """ 175 | level = _img_enhance_to_arg(magnitude) 176 | return iaa.pillike.EnhanceColor(level) 177 | 178 | 179 | def contrast(magnitude: int) -> iaa.GammaContrast: 180 | """ 181 | Adjust the contrast of an image 182 | 183 | Tensorflow Policy Equivalent: contrast 184 | 185 | :type magnitude: int 186 | :param magnitude: magnitude of contrast change 187 | :rtype: iaa.GammaContrast 188 | :return: Method to adjust contrast of image 189 | """ 190 | level = _img_enhance_to_arg(magnitude) 191 | return iaa.GammaContrast(level) 192 | 193 | 194 | @validate_magnitude 195 | def cutout(magnitude: int, **kwargs) -> iaa.Cutout: 196 | """ 197 | Apply cutout anywhere in the image. Passing the height and width 198 | of the image as integers and as keywords will scale the bounding 199 | box according to the policy 200 | 201 | Tensorflow Policy Equivalent: cutout 202 | 203 | The cutout value in the policies is at a pixel level. The imgaug cutout 204 | augmentation method requires the cutout to be a percentage of the image. 205 | Passing the image height and width as kwargs will scale the cutout to 206 | the appropriate percentage. Otherwise the imgaug default of 20% will be 207 | used. 
208 | 209 | :type magnitude: int 210 | :param magnitude: magnitude of cutout 211 | :param kwargs: 212 | height: height of the image as int 213 | width: width of the image as int 214 | :rtype: iaa.Cutout 215 | :return: Method to apply cutout to image 216 | """ 217 | level = int((magnitude / _MAX_MAGNITUDE) * CUTOUT_CONST) 218 | cutout_args = {} 219 | if 'height' in kwargs and 'width' in kwargs: 220 | size = tuple(np.clip([ 221 | (level / kwargs['height']) * 2, 222 | (level / kwargs['width']) * 2 223 | ], 0.0, 1.0)) 224 | cutout_args['size'] = size 225 | return iaa.Cutout(**cutout_args) 226 | 227 | 228 | @validate_magnitude 229 | def cutout_fraction(magnitude: int, **kwargs) -> iaa.Cutout: 230 | """ 231 | Applies cutout to the image according to bbox information. This will 232 | apply only to a single bounding box in the image. For the augmentation 233 | to apply the policy correctly the image height and width along with the 234 | bounding box height and width are required as keyword arguments. 235 | 236 | Tensorflow Policy Equivalent: bbox_cutout 237 | 238 | The cutout size is determined as a fraction of the bounding box size. 239 | The cutout value in the policies is at a pixel level. The imgaug cutout 240 | augmentation method requires the cutout to be a percentage of the image. 241 | Passing the image height and width as kwargs will scale the cutout to the 242 | appropriate percentage. Otherwise the imgaug default of 20% will be used. 
243 | 244 | Note: the cutout may not always be present in the bounding box dut to 245 | randomness in the location of the cutout centre 246 | 247 | :type magnitude: int 248 | :param magnitude: magnitude of cutout 249 | :param kwargs: 250 | height: height of the image as int 251 | width: width of the image as int 252 | height_bbox: height of the bounding box as int 253 | width_bbox: width of the bounding box as int 254 | :rtype: iaa.Cutout 255 | :return: Method to apply cutout to bounding boxes 256 | """ 257 | level = (magnitude / _MAX_MAGNITUDE) * CUTOUT_MAX_PAD_FRACTION 258 | cutout_args = {} 259 | if all( 260 | i in kwargs 261 | for i in ['height', 'width', 'height_bbox', 'width_bbox'] 262 | ): 263 | size = tuple([ 264 | (level * kwargs['height_bbox']) / kwargs['height'], 265 | (level * kwargs['width_bbox']) / kwargs['width'] 266 | ]) 267 | cutout_args['size'] = size 268 | return iaa.Cutout(**cutout_args) 269 | 270 | 271 | def cutout_bbox(magnitude: int, **kwargs) -> iaa.BlendAlphaBoundingBoxes: 272 | """ 273 | Only apply cutout to the bounding box area. Passing the 274 | height and width of the image as integers and as keywords 275 | will scale the bounding box according to the policy. Note, the 276 | cutout location is chosen randomly and will only appear if it 277 | falls within the bounding box. 
278 | 279 | :type magnitude: int 280 | :param magnitude: magnitude of cutout 281 | :param kwargs: 282 | height: height of the image as int 283 | width: width of the image as int 284 | :rtype: iaa.BlendAlphaBoundingBoxes 285 | :return: Method to apply cutout only to bounding boxes 286 | """ 287 | level = int((magnitude/_MAX_MAGNITUDE) * CUTOUT_BBOX) 288 | cutout_args = {} 289 | if 'height' in kwargs and 'width' in kwargs: 290 | size = tuple(np.clip([ 291 | level / kwargs['height'], 292 | level / kwargs['width'] 293 | ], 0.0, 1.0)) 294 | cutout_args['size'] = size 295 | return iaa.BlendAlphaBoundingBoxes( 296 | None, 297 | foreground=iaa.Cutout(**cutout_args) 298 | ) 299 | 300 | 301 | def equalise(_: int) -> iaa.AllChannelsHistogramEqualization: 302 | """ 303 | Apply auto histogram equalisation to the image 304 | 305 | Tensorflow Policy Equivalent: equalize 306 | 307 | :type _: int 308 | :param _: unused magnitude 309 | :rtype: iaa.AllChannelsHistogramEqualization 310 | :return: Method to equalise image 311 | """ 312 | return iaa.AllChannelsHistogramEqualization() 313 | 314 | 315 | def fliplr_boxes(_: int) -> iaa.BlendAlphaBoundingBoxes: 316 | """ 317 | Flip only the bounding boxes horizontally 318 | 319 | Tensorflow Policy Equivalent: flip_only_bboxes 320 | 321 | :type _: int 322 | :param _: Unused, kept to fit within the ecosystem 323 | :rtype: iaa.AllChannelsHistogramEqualization 324 | :return: Method to flip bounding boxes horizontally 325 | """ 326 | return iaa.BlendAlphaBoundingBoxes( 327 | None, 328 | foreground=iaa.Fliplr(1.0) 329 | ) 330 | 331 | 332 | @validate_magnitude 333 | def posterize(magnitude: int): 334 | """ 335 | Posterize image 336 | 337 | Tensorflow Policy Equivalent: posterize 338 | 339 | :type magnitude: int 340 | :param magnitude: magnitude of posterize 341 | :rtype: iaa.AllChannelsHistogramEqualization 342 | :return: Method to posterize image 343 | """ 344 | nbits = int((magnitude / _MAX_MAGNITUDE) * 4) 345 | if nbits == 0: 346 | nbits += 1 
347 | return iaa.color.Posterize(nb_bits=nbits) 348 | 349 | 350 | def rotate(magnitude: int) -> iaa.BlendAlphaBoundingBoxes: 351 | """ 352 | Rotate the bounding box in an image 353 | 354 | Tensorflow Policy Equivalent: rotate_with_bboxes 355 | 356 | :type magnitude: int 357 | :param magnitude: magnitude of rotation 358 | :rtype: iaa.BlendAlphaBoundingBoxes 359 | :return: Method to apply rotation 360 | """ 361 | level = _rotate_mag_to_arg(magnitude) 362 | return iaa.Rotate(level) 363 | 364 | 365 | def sharpness(magnitude: int) -> iaa.pillike.EnhanceSharpness: 366 | """ 367 | Add sharpness to the image 368 | 369 | Tensorflow Policy Equivalent: sharpness 370 | 371 | :type magnitude: int 372 | :param magnitude: magnitude of sharpness 373 | :rtype: iaa.pillike.EnhanceSharpness 374 | :return: Method to adjust sharpness 375 | """ 376 | level = _img_enhance_to_arg(magnitude) 377 | return iaa.pillike.EnhanceSharpness(level) 378 | 379 | 380 | def shear_x(magnitude: int) -> iaa.ShearY: 381 | """ 382 | Apply x shear to the image and boxes 383 | 384 | Tensorflow Policy Equivalent: shear_x 385 | 386 | :type magnitude: int 387 | :param magnitude: magnitude of y shear 388 | :rtype: iaa.ShearY 389 | :return: Method to y shear bounding boxes 390 | """ 391 | level = _shear_mag_to_arg(magnitude) 392 | return iaa.ShearX(level) 393 | 394 | 395 | def shear_x_bbox(magnitude: int) -> iaa.BlendAlphaBoundingBoxes: 396 | """ 397 | Apply x shear only to bboxes 398 | 399 | Tensorflow Policy Equivalent: shear_x_only_bboxes 400 | 401 | :type magnitude: int 402 | :param magnitude: magnitude of x shear 403 | :rtype: iaa.BlendAlphaBoundingBoxes 404 | :return: Method to x shear bounding boxes 405 | """ 406 | level = _shear_mag_to_arg(magnitude) 407 | return iaa.BlendAlphaBoundingBoxes( 408 | None, 409 | foreground=iaa.ShearX(level), 410 | ) 411 | 412 | 413 | def shear_y(magnitude: int) -> iaa.ShearY: 414 | """ 415 | Apply y shear image and boxes 416 | 417 | Tensorflow Policy Equivalent: shear_y 418 | 
419 | :type magnitude: int 420 | :param magnitude: magnitude of y shear 421 | :rtype: iaa.ShearY 422 | :return: Method to y shear bounding boxes 423 | """ 424 | level = _shear_mag_to_arg(magnitude) 425 | return iaa.ShearY(level) 426 | 427 | 428 | def shear_y_bbox(magnitude: int) -> iaa.BlendAlphaBoundingBoxes: 429 | """ 430 | Apply y shear only to bboxes 431 | 432 | Tensorflow Policy Equivalent: shear_y_only_bboxes 433 | 434 | :type magnitude: int 435 | :param magnitude: magnitude of y shear 436 | :rtype: iaa.BlendAlphaBoundingBoxes 437 | :return: Method to y shear bounding boxes 438 | """ 439 | level = _shear_mag_to_arg(magnitude) 440 | return iaa.BlendAlphaBoundingBoxes( 441 | None, 442 | foreground=iaa.ShearY(level), 443 | ) 444 | 445 | 446 | def solarize(_: int) -> iaa.pillike.Solarize: 447 | """ 448 | Solarize the image 449 | 450 | :type _: int 451 | :param _: Unused, kept to fit within the ecosystem 452 | :rtype: iaa.pillike.Solarize 453 | :return: Method to solarize image 454 | """ 455 | return iaa.pillike.Solarize(threshold=128) 456 | 457 | 458 | @validate_magnitude 459 | def solarize_add(magnitude: int): 460 | """ 461 | Add solarize to an image 462 | 463 | Tensorflow Policy Equivalent: solarize_add 464 | 465 | :type magnitude:int 466 | :param magnitude: Magnitude of solarization 467 | :rtype: aug 468 | :return: Method to apply solarization 469 | """ 470 | level = int((magnitude / _MAX_MAGNITUDE) * 110) 471 | 472 | def aug(image, bounding_boxes, threshold=128): 473 | image_added, image_copy = image.copy(), image.copy() 474 | image_added = image_added + level 475 | image_added = np.clip(image_added, 0, 255) 476 | image_copy[np.where(image_copy < threshold)] = image_added[np.where(image_copy < threshold)] # noqa: 501 477 | return image_copy, bounding_boxes 478 | return aug 479 | 480 | 481 | def translate_x(magnitude: int) -> iaa.geometric.TranslateX: 482 | """ 483 | Translate bounding boxes only on the x-axis 484 | 485 | Tensorflow Policy Equivalent: 
translate_x_only_bboxes 486 | 487 | :type magnitude: int 488 | :param magnitude: Magnitude of translation 489 | :rtype: iaa.BlendAlphaBoundingBoxes 490 | :return: Method to apply x translation to bounding boxes 491 | """ 492 | level = _translate_mag_to_arg(magnitude) 493 | return iaa.geometric.TranslateX(px=level) 494 | 495 | 496 | def translate_x_bbox(magnitude: int) -> iaa.BlendAlphaBoundingBoxes: 497 | """ 498 | Translate bounding boxes only on the x-axis 499 | 500 | Tensorflow Policy Equivalent: translate_x 501 | 502 | :type magnitude: int 503 | :param magnitude: Magnitude of translation 504 | :rtype: iaa.BlendAlphaBoundingBoxes 505 | :return: Method to apply x translation to bounding boxes 506 | """ 507 | level = _translate_mag_to_arg(magnitude, bbox=True) 508 | return iaa.BlendAlphaBoundingBoxes( 509 | None, 510 | foreground=iaa.geometric.TranslateX(px=level), 511 | ) 512 | 513 | 514 | def translate_y(magnitude: int) -> iaa.geometric.TranslateY: 515 | """ 516 | Translate bounding boxes only on the y-axis 517 | 518 | Tensorflow Policy Equivalent: translate_y_only_bboxes 519 | 520 | :type magnitude: int 521 | :param magnitude: magnitude of translation 522 | :rtype: iaa.BlendAlphaBoundingBoxes 523 | :return: Method to apply y translation to bounding boxes 524 | """ 525 | level = _translate_mag_to_arg(magnitude) 526 | return iaa.geometric.TranslateY(px=level) 527 | 528 | 529 | def translate_y_bbox(magnitude: int) -> iaa.BlendAlphaBoundingBoxes: 530 | """ 531 | Translate bounding boxes only on the y-axis 532 | 533 | Tensorflow Policy Equivalent: translate_y 534 | 535 | :type magnitude: int 536 | :param magnitude: magnitude of translation 537 | :rtype: iaa.BlendAlphaBoundingBoxes 538 | :return: Method to apply y translation to bounding boxes 539 | """ 540 | level = _translate_mag_to_arg(magnitude, bbox=True) 541 | return iaa.BlendAlphaBoundingBoxes( 542 | None, 543 | foreground=iaa.geometric.TranslateY(px=level) 544 | ) 545 | 546 | 547 | NAME_TO_AUGMENTATION = { 
548 | 'Auto_Contrast': auto_contrast, 549 | 'Brightness': brightness, 550 | 'Cutout': cutout, 551 | 'Cutout_BBox': cutout_bbox, 552 | 'Cutout_Fraction': cutout_fraction, 553 | 'Color': colour, 554 | 'Contrast': contrast, 555 | 'Equalize': equalise, 556 | 'Fliplr_BBox': fliplr_boxes, 557 | 'Posterize': posterize, 558 | 'Rotate': rotate, 559 | 'Sharpness': sharpness, 560 | 'Shear_X': shear_x, 561 | 'Shear_X_BBox': shear_x_bbox, 562 | 'Shear_Y': shear_y, 563 | 'Shear_Y_BBox': shear_y_bbox, 564 | 'Solarize': solarize, 565 | 'Solarize_Add': solarize_add, 566 | 'Translate_X': translate_x, 567 | 'Translate_X_BBox': translate_x_bbox, 568 | 'Translate_Y': translate_y, 569 | 'Translate_Y_BBox': translate_y_bbox, 570 | } 571 | -------------------------------------------------------------------------------- /bbaug/exceptions.py: -------------------------------------------------------------------------------- 1 | """ Custom Exceptions """ 2 | 3 | 4 | class BaseError(Exception): 5 | """ Base Exception """ 6 | 7 | 8 | class InvalidMagnitude(BaseError): 9 | """ Error if magnitude is too large or too small """ 10 | -------------------------------------------------------------------------------- /bbaug/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .policies import * # noqa 2 | -------------------------------------------------------------------------------- /bbaug/policies/policies.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module containing augmentation policies 3 | Ref: https://github.com/tensorflow/tpu/blob/2264f53d95852efbfb82ea27f03ca749e1205968/models/official/detection/utils/autoaugment_utils.py # noqa: 501 4 | """ 5 | 6 | from collections import namedtuple 7 | import random 8 | from typing import ( 9 | Callable, 10 | Dict, 11 | Iterable, 12 | List, 13 | NamedTuple, 14 | Tuple, 15 | Union, 16 | ) 17 | 18 | from imgaug.augmentables.bbs import ( 19 | 
BoundingBox, 20 | BoundingBoxesOnImage, 21 | ) 22 | import numpy as np 23 | 24 | from bbaug.augmentations.augmentations import NAME_TO_AUGMENTATION 25 | 26 | POLICY_TUPLE_TYPE = NamedTuple( 27 | 'policy', 28 | [('name', str), ('probability', float), ('magnitude', str)] 29 | ) 30 | POLICY_TUPLE = namedtuple('policy', ['name', 'probability', 'magnitude']) 31 | 32 | __all__ = [ 33 | 'POLICY_TUPLE_TYPE', 34 | 'POLICY_TUPLE', 35 | 'list_policies', 36 | 'policies_v0', 37 | 'policies_v1', 38 | 'policies_v2', 39 | 'policies_v3', 40 | 'PolicyContainer', 41 | ] 42 | 43 | 44 | def policies_v0(): 45 | """ 46 | Version of the policies used in the paper 47 | ​ 48 | :rtype: List[List[POLICY_TUPLE_TYPE]] 49 | :return: List of policies 50 | """ 51 | policy = [ 52 | [ 53 | POLICY_TUPLE('Translate_X', 0.6, 4), 54 | POLICY_TUPLE('Equalize', 0.8, 10) 55 | ], 56 | [ 57 | POLICY_TUPLE('Translate_Y_BBox', 0.2, 2), 58 | POLICY_TUPLE('Cutout', 0.8, 8) 59 | ], 60 | [ 61 | POLICY_TUPLE('Sharpness', 0.0, 8), 62 | POLICY_TUPLE('Shear_X', 0.4, 0) 63 | ], 64 | [ 65 | POLICY_TUPLE('Shear_Y', 1.0, 2), 66 | POLICY_TUPLE('Translate_Y_BBox', 0.6, 6) 67 | ], 68 | [ 69 | POLICY_TUPLE('Rotate', 0.6, 10), 70 | POLICY_TUPLE('Color', 1.0, 6)], 71 | ] 72 | return policy 73 | 74 | 75 | def policies_v1() -> List[List[POLICY_TUPLE_TYPE]]: 76 | """ 77 | Version 1 of augmentation policies 78 | ​ 79 | :rtype: List[List[POLICY_TUPLE_TYPE]] 80 | :return: List of policies 81 | """ 82 | policy = [ 83 | [ 84 | POLICY_TUPLE('Translate_X', 0.6, 4), 85 | POLICY_TUPLE('Equalize', 0.8, 10) 86 | ], 87 | [ 88 | POLICY_TUPLE('Translate_Y_BBox', 0.2, 2), 89 | POLICY_TUPLE('Cutout', 0.8, 8)], 90 | [ 91 | POLICY_TUPLE('Sharpness', 0.0, 8), 92 | POLICY_TUPLE('Shear_X', 0.4, 0) 93 | ], 94 | [ 95 | POLICY_TUPLE('Shear_Y', 1.0, 2), 96 | POLICY_TUPLE('Translate_Y_BBox', 0.6, 6) 97 | ], 98 | [ 99 | POLICY_TUPLE('Rotate', 0.6, 10), 100 | POLICY_TUPLE('Color', 1.0, 6) 101 | ], 102 | [ 103 | POLICY_TUPLE('Color', 0.0, 0), 104 | 
POLICY_TUPLE('Shear_X_BBox', 0.8, 4) 105 | ], 106 | [ 107 | POLICY_TUPLE('Shear_Y_BBox', 0.8, 2), 108 | POLICY_TUPLE('Fliplr_BBox', 0.0, 10) 109 | ], 110 | [ 111 | POLICY_TUPLE('Equalize', 0.6, 10), 112 | POLICY_TUPLE('Translate_X', 0.2, 2) 113 | ], 114 | [ 115 | POLICY_TUPLE('Color', 1.0, 10), 116 | POLICY_TUPLE('Translate_Y_BBox', 0.4, 6) 117 | ], 118 | [ 119 | POLICY_TUPLE('Rotate', 0.8, 10), 120 | POLICY_TUPLE('Contrast', 0.0, 10) 121 | ], 122 | [ 123 | POLICY_TUPLE('Cutout', 0.2, 2), 124 | POLICY_TUPLE('Brightness', 0.8, 10) 125 | ], 126 | [ 127 | POLICY_TUPLE('Color', 1.0, 6), 128 | POLICY_TUPLE('Equalize', 1.0, 2) 129 | ], 130 | [ 131 | POLICY_TUPLE('Cutout_BBox', 0.4, 6), 132 | POLICY_TUPLE('Translate_Y_BBox', 0.8, 2) 133 | ], 134 | [ 135 | POLICY_TUPLE('Color', 0.2, 8), 136 | POLICY_TUPLE('Rotate', 0.8, 10) 137 | ], 138 | [ 139 | POLICY_TUPLE('Sharpness', 0.4, 4), 140 | POLICY_TUPLE('Translate_Y_BBox', 0.0, 4) 141 | ], 142 | [ 143 | POLICY_TUPLE('Sharpness', 1.0, 4), 144 | POLICY_TUPLE('Solarize_Add', 0.4, 4) 145 | ], 146 | [ 147 | POLICY_TUPLE('Rotate', 1.0, 8), 148 | POLICY_TUPLE('Sharpness', 0.2, 8) 149 | ], 150 | [ 151 | POLICY_TUPLE('Shear_Y', 0.6, 10), 152 | POLICY_TUPLE('Translate_Y_BBox', 0.6, 8) 153 | ], 154 | [ 155 | POLICY_TUPLE('Shear_X', 0.2, 6), 156 | POLICY_TUPLE('Translate_Y_BBox', 0.2, 10) 157 | ], 158 | [ 159 | POLICY_TUPLE('Solarize_Add', 0.6, 8), 160 | POLICY_TUPLE('Brightness', 0.8, 10) 161 | ], 162 | ] 163 | return policy 164 | 165 | 166 | def policies_v2() -> List[List[POLICY_TUPLE_TYPE]]: 167 | """ 168 | Version 2 of augmentation policies 169 | ​ 170 | :rtype: List[List[POLICY_TUPLE_TYPE]] 171 | :return: List of policies 172 | """ 173 | policy = [ 174 | [ 175 | POLICY_TUPLE('Color', 0.0, 6), 176 | POLICY_TUPLE('Cutout', 0.6, 8), 177 | POLICY_TUPLE('Sharpness', 0.4, 8) 178 | ], 179 | [ 180 | POLICY_TUPLE('Rotate', 0.4, 8), 181 | POLICY_TUPLE('Sharpness', 0.4, 2), 182 | POLICY_TUPLE('Rotate', 0.8, 10) 183 | ], 184 | [ 185 | 
POLICY_TUPLE('Translate_Y', 1.0, 8), 186 | POLICY_TUPLE('Auto_Contrast', 0.8, 2) 187 | ], 188 | [ 189 | POLICY_TUPLE('Auto_Contrast', 0.4, 6), 190 | POLICY_TUPLE('Shear_X', 0.8, 8), 191 | POLICY_TUPLE('Brightness', 0.0, 10) 192 | ], 193 | [ 194 | POLICY_TUPLE('Solarize_Add', 0.2, 6), 195 | POLICY_TUPLE('Contrast', 0.0, 10), 196 | POLICY_TUPLE('Auto_Contrast', 0.6, 0) 197 | ], 198 | [ 199 | POLICY_TUPLE('Cutout', 0.2, 0), 200 | POLICY_TUPLE('Solarize', 0.8, 8), 201 | POLICY_TUPLE('Color', 1.0, 4) 202 | ], 203 | [ 204 | POLICY_TUPLE('Translate_Y', 0.0, 4), 205 | POLICY_TUPLE('Equalize', 0.6, 8), 206 | POLICY_TUPLE('Solarize', 0.0, 10) 207 | ], 208 | [ 209 | POLICY_TUPLE('Translate_Y', 0.2, 2), 210 | POLICY_TUPLE('Shear_Y', 0.8, 8), 211 | POLICY_TUPLE('Rotate', 0.8, 8) 212 | ], 213 | [ 214 | POLICY_TUPLE('Cutout', 0.8, 8), 215 | POLICY_TUPLE('Brightness', 0.8, 8), 216 | POLICY_TUPLE('Cutout', 0.2, 2) 217 | ], 218 | [ 219 | POLICY_TUPLE('Color', 0.8, 4), 220 | POLICY_TUPLE('Translate_Y', 1.0, 6), 221 | POLICY_TUPLE('Rotate', 0.6, 6) 222 | ], 223 | [ 224 | POLICY_TUPLE('Rotate', 0.6, 10), 225 | POLICY_TUPLE('Cutout_Fraction', 1.0, 4), 226 | POLICY_TUPLE('Cutout', 0.2, 8) 227 | ], 228 | [ 229 | POLICY_TUPLE('Rotate', 0.0, 0), 230 | POLICY_TUPLE('Equalize', 0.6, 6), 231 | POLICY_TUPLE('Shear_Y', 0.6, 8) 232 | ], 233 | [ 234 | POLICY_TUPLE('Brightness', 0.8, 8), 235 | POLICY_TUPLE('Auto_Contrast', 0.4, 2), 236 | POLICY_TUPLE('Brightness', 0.2, 2) 237 | ], 238 | [ 239 | POLICY_TUPLE('Translate_Y', 0.4, 8), 240 | POLICY_TUPLE('Solarize', 0.4, 6), 241 | POLICY_TUPLE('Solarize_Add', 0.2, 10) 242 | ], 243 | [ 244 | POLICY_TUPLE('Contrast', 1.0, 10), 245 | POLICY_TUPLE('Solarize_Add', 0.2, 8), 246 | POLICY_TUPLE('Equalize', 0.2, 4) 247 | ], 248 | ] 249 | return policy 250 | 251 | 252 | def policies_v3() -> List[List[POLICY_TUPLE_TYPE]]: 253 | """ 254 | Version 3 of augmentation policies 255 | ​ 256 | :rtype: List[List[POLICY_TUPLE_TYPE]] 257 | :return: List of policies 258 | """ 
259 | policy = [ 260 | [ 261 | POLICY_TUPLE('Posterize', 0.8, 2), 262 | POLICY_TUPLE('Translate_X', 1.0, 8) 263 | ], 264 | [ 265 | POLICY_TUPLE('Cutout_Fraction', 0.2, 10), 266 | POLICY_TUPLE('Sharpness', 1.0, 8) 267 | ], 268 | [ 269 | POLICY_TUPLE('Rotate', 0.6, 8), 270 | POLICY_TUPLE('Rotate', 0.8, 10) 271 | ], 272 | [ 273 | POLICY_TUPLE('Equalize', 0.8, 10), 274 | POLICY_TUPLE('Auto_Contrast', 0.2, 10) 275 | ], 276 | [ 277 | POLICY_TUPLE('Solarize_Add', 0.2, 2), 278 | POLICY_TUPLE('Translate_Y', 0.2, 8) 279 | ], 280 | [ 281 | POLICY_TUPLE('Sharpness', 0.0, 2), 282 | POLICY_TUPLE('Color', 0.4, 8) 283 | ], 284 | [ 285 | POLICY_TUPLE('Equalize', 1.0, 8), 286 | POLICY_TUPLE('Translate_Y', 1.0, 8) 287 | ], 288 | [ 289 | POLICY_TUPLE('Posterize', 0.6, 2), 290 | POLICY_TUPLE('Rotate', 0.0, 10) 291 | ], 292 | [ 293 | POLICY_TUPLE('Auto_Contrast', 0.6, 0), 294 | POLICY_TUPLE('Rotate', 1.0, 6) 295 | ], 296 | [ 297 | POLICY_TUPLE('Equalize', 0.0, 4), 298 | POLICY_TUPLE('Cutout', 0.8, 10) 299 | ], 300 | [ 301 | POLICY_TUPLE('Brightness', 1.0, 2), 302 | POLICY_TUPLE('Translate_Y', 1.0, 6) 303 | ], 304 | [ 305 | POLICY_TUPLE('Contrast', 0.0, 2), 306 | POLICY_TUPLE('Shear_Y', 0.8, 0) 307 | ], 308 | [ 309 | POLICY_TUPLE('Auto_Contrast', 0.8, 10), 310 | POLICY_TUPLE('Contrast', 0.2, 10) 311 | ], 312 | [ 313 | POLICY_TUPLE('Rotate', 1.0, 10), 314 | POLICY_TUPLE('Cutout', 1.0, 10) 315 | ], 316 | [ 317 | POLICY_TUPLE('Solarize_Add', 0.8, 6), 318 | POLICY_TUPLE('Equalize', 0.8, 8) 319 | ], 320 | ] 321 | return policy 322 | 323 | 324 | def list_policies() -> List: 325 | """ 326 | Returns a list of policies available 327 | 328 | :rtype: List 329 | :return: List of available policies 330 | """ 331 | return [ 332 | policies_v0.__name__, 333 | policies_v1.__name__, 334 | policies_v2.__name__, 335 | policies_v3.__name__, 336 | ] 337 | 338 | 339 | class PolicyContainer: 340 | """ 341 | Policy container for all the policies available during augmentation 342 | """ 343 | 344 | def __init__( 
345 | self, 346 | policy_set: List[List[POLICY_TUPLE_TYPE]], 347 | name_to_augmentation: Dict[str, Callable] = NAME_TO_AUGMENTATION, 348 | return_yolo: bool = False, 349 | random_state: Union[None, int] = None, 350 | ): 351 | """ 352 | Policy container initialisation 353 | 354 | :type policy_list: List[List[POLICY_TUPLE_TYPE]] 355 | :param policy_list: List of policies available for augmentation 356 | :type name_to_augmentation: Dict[str, Callable] 357 | :param name_to_augmentation: Mapping of augmentation name to function 358 | reference 359 | :type return_yolo: bool 360 | :param return_yolo: Flag for returning the bounding boxes in YOLO 361 | format 362 | :type random_state: Union[None, int] 363 | :param random_state: Provide a random state for reproducibility 364 | """ 365 | self.policies = policy_set 366 | self.augmentations = name_to_augmentation 367 | self.return_yolo = return_yolo 368 | if random_state is not None: 369 | random.seed(random_state) 370 | np.random.seed(random_state) 371 | 372 | def __getitem__(self, item: str) -> Callable: 373 | """ 374 | Returns the augmentation method reference 375 | 376 | :type item: str 377 | :param item: Name of augmentation method 378 | :rtype Callable 379 | :return: Augmentation method 380 | """ 381 | return self.augmentations[item] 382 | 383 | def _bbs_to_percent( 384 | self, 385 | bounding_boxes: List[BoundingBox], 386 | image_height: int, 387 | image_width: int, 388 | ) -> np.array: 389 | """ 390 | Convert the augmented bounding boxes to YOLO format: 391 | [x_centre, y_centre, box_width, box_height] 392 | 393 | :type bounding_boxes: List[BoundingBox] 394 | :param bounding_boxes: list of augmented bounding boxes 395 | :type image_height: int 396 | :param image_height: Height of the image 397 | :type image_width: int 398 | :param image_width: Width of the image 399 | :rtype: np.array 400 | :return: Numpy array of augmented bounding boxes 401 | """ 402 | return np.array([ 403 | [ 404 | bb.label, 405 | bb.center_x / 
image_width, 406 | bb.center_y / image_height, 407 | bb.width / image_width, 408 | bb.height / image_height 409 | ] 410 | for bb in bounding_boxes 411 | ]) 412 | 413 | def _bbs_to_pixel(self, bounding_boxes: List[BoundingBox]) -> np.array: 414 | """ 415 | Return the augmented bounding boxes in pixel format: 416 | [x_min, y_min, x_max, y_max] 417 | 418 | :type bounding_boxes: List[BoundingBox] 419 | :param bounding_boxes: 420 | :rtype: np.array 421 | :return: Numpy array of augmented bounding boxes 422 | """ 423 | return np.array([ 424 | [ 425 | bb.label, 426 | bb.x1, 427 | bb.y1, 428 | bb.x2, 429 | bb.y2 430 | ] 431 | for bb in bounding_boxes 432 | ]).astype('int32') 433 | 434 | def _cutout_kwargs(self, image_shape: Tuple[int, int]) -> Dict[str, int]: 435 | """ 436 | Returns the kwargs for cutout augmentations 437 | 438 | :type image_shape: Tuple[int, int] 439 | :param image_shape: Shape of the image 440 | :rtype: Dict[str, int] 441 | :return: Kwargs for cutout augmentations 442 | """ 443 | return { 444 | 'height': image_shape[0], 445 | 'width': image_shape[1] 446 | } 447 | 448 | def select_random_policy(self) -> List[POLICY_TUPLE]: 449 | """ 450 | Selects a random policy from the list of available policies 451 | 452 | :rtype: List[POLICY_TUPLE] 453 | :return: Randomly selected policy 454 | """ 455 | return random.choice(self.policies) 456 | 457 | def apply_augmentation( 458 | self, 459 | policy: List[POLICY_TUPLE], 460 | image: np.array, 461 | bounding_boxes: List[List[int]], 462 | labels: Iterable[int], 463 | ) -> Tuple[np.array, np.array]: 464 | """ 465 | Applies the augmentations to the image. 
466 | 467 | :type policy: List[POLICY_TUPLE] 468 | :param policy: Augmentation policy to apply to the image 469 | :type image: np.array 470 | :param image: Image to augment 471 | :type bounding_boxes: List[List[int]] 472 | :param bounding_boxes: Bounding boxes for the image in the format: 473 | [x_min, y_min, x_max, y_max] 474 | :type labels: Iterable[int] 475 | :param labels: Iterable containing class labels as integers 476 | :rtype: Tuple[np.array, np.array] 477 | :return: Tuple containing the augmented image and bounding boxes 478 | """ 479 | bbs = BoundingBoxesOnImage( 480 | [ 481 | BoundingBox(*bb, label=label) 482 | for bb, label in zip(bounding_boxes, labels) 483 | ], 484 | image.shape 485 | ) 486 | for i in policy: 487 | if i.name.endswith('BBox'): 488 | new_bbs = [] 489 | for box in bbs: 490 | if i.probability > np.random.random(): 491 | if i.name == 'Cutout_BBox': 492 | kwargs = self._cutout_kwargs(image.shape) 493 | aug = self[i.name](i.magnitude, **kwargs) 494 | else: 495 | aug = self[i.name](i.magnitude) 496 | image, box_aug = aug( 497 | image=image, 498 | bounding_boxes=BoundingBoxesOnImage([box], image.shape) # noqa: E501 499 | ) 500 | new_bbs.append(box_aug[0]) 501 | else: 502 | new_bbs.append(box) 503 | bbs = BoundingBoxesOnImage(new_bbs, image.shape) 504 | elif i.probability > np.random.random(): 505 | if i.name == 'Cutout': 506 | kwargs = self._cutout_kwargs(image.shape) 507 | aug = self[i.name](i.magnitude, **kwargs) 508 | elif i.name == 'Cutout_Fraction': 509 | if len(bbs) == 0: 510 | aug = self[i.name](i.magnitude) 511 | else: 512 | random_bb = np.random.choice(bbs) 513 | kwargs = { 514 | 'height_bbox': random_bb.height, 515 | 'width_bbox': random_bb.width, 516 | 'height': image.shape[0], 517 | 'width': image.shape[1] 518 | } 519 | aug = self[i.name](i.magnitude, **kwargs) 520 | else: 521 | aug = self[i.name](i.magnitude) 522 | image, bbs = aug(image=image, bounding_boxes=bbs) 523 | bbs = bbs.remove_out_of_image().clip_out_of_image() 524 | if 
self.return_yolo: 525 | bbs = self._bbs_to_percent(bbs, image.shape[0], image.shape[1]) 526 | else: 527 | bbs = self._bbs_to_pixel(bbs) 528 | return image, bbs 529 | -------------------------------------------------------------------------------- /bbaug/visuals/__init__.py: -------------------------------------------------------------------------------- 1 | from .visuals import * # noqa 2 | -------------------------------------------------------------------------------- /bbaug/visuals/visuals.py: -------------------------------------------------------------------------------- 1 | """ Module to visualise policies """ 2 | 3 | from typing import ( 4 | Callable, 5 | Dict, 6 | Iterable, 7 | List 8 | ) 9 | 10 | import imageio 11 | from imgaug.augmentables.bbs import ( 12 | BoundingBox, 13 | BoundingBoxesOnImage, 14 | ) 15 | import matplotlib.pyplot as plt 16 | 17 | from bbaug.augmentations import NAME_TO_AUGMENTATION 18 | from bbaug.policies import POLICY_TUPLE_TYPE, PolicyContainer 19 | 20 | __all__ = [ 21 | 'visualise_policy' 22 | ] 23 | 24 | 25 | def visualise_policy( 26 | image_path: str, 27 | save_path: str, 28 | bounding_boxes: List[List[int]], 29 | labels: Iterable[int], 30 | policy: List[List[POLICY_TUPLE_TYPE]], 31 | name_to_augmentation: Dict[str, Callable] = NAME_TO_AUGMENTATION 32 | ) -> None: 33 | """ 34 | Visualise a single policy on an image 35 | 36 | :type image_path: str 37 | :param image_path: Path of the image 38 | :type save_path: str 39 | :param save_path: Directory where to save the images 40 | :type bounding_boxes: List[List[int]] 41 | :param bounding_boxes: Bounding boxes for the image 42 | :type labels: Iterable[int] 43 | :param labels: Iterable containing class labels as integers 44 | :type policy: List[List[POLICY_TUPLE_TYPE]] 45 | :param policy: The policy set to apply to the image 46 | :type name_to_augmentation: Dict[str, Callable] 47 | :param name_to_augmentation: Dictionary mapping of the augmentation name 48 | to the augmentation method 
49 | :rtype: None 50 | """ 51 | policy_container = PolicyContainer( 52 | policy, 53 | name_to_augmentation=name_to_augmentation 54 | ) 55 | 56 | for i, pol in enumerate(policy): 57 | fig, axes = plt.subplots(1, 3, figsize=(15, 4)) 58 | image = imageio.imread(image_path) 59 | [ax.axis('off') for ax in axes] 60 | for ax in range(3): 61 | img_aug, bbs_aug = policy_container.apply_augmentation( 62 | pol, 63 | image, 64 | bounding_boxes, 65 | labels, 66 | ) 67 | bbs_aug = BoundingBoxesOnImage([ 68 | BoundingBox(*box[1:], label=box[0]) 69 | for box in bbs_aug 70 | ], shape=image.shape) 71 | axes[ax].imshow(bbs_aug.draw_on_image(img_aug, size=2)) 72 | fig.suptitle(pol) 73 | fig.tight_layout() 74 | fig.savefig(f'{save_path}/sub_policy_{i}.png') 75 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/conftest.py -------------------------------------------------------------------------------- /coverage.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 100% 19 | 100% 20 | 21 | 22 | -------------------------------------------------------------------------------- /notebooks/data/example_dataset/boxes/dog_1.txt: -------------------------------------------------------------------------------- 1 | 0 42 71 266 244 2 | 1 279 48 341 71 -------------------------------------------------------------------------------- /notebooks/data/example_dataset/images/dog_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harpalsahota/bbaug/8beca7b3444ec693edc7a2acc279ae67f24e3dfe/notebooks/data/example_dataset/images/dog_1.jpg 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "bbaug" 3 | version = "0.4.2" 4 | description = "Bounding box augmentations for Pytorch" 5 | authors = ["Harpal Sahota"] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.6" 9 | torch = "^1.4.0" 10 | torchvision = "^0.5.0" 11 | imgaug = "^0.4.0" 12 | # For Windows Users 13 | #torch = {url = "https://download.pytorch.org/whl/cu101/torch-1.4.0-cp37-cp37m-win_amd64.whl"} 14 | #torchvision = {url = "https://download.pytorch.org/whl/cu101/torchvision-0.5.0-cp37-cp37m-win_amd64.whl"} 15 | 16 | [tool.poetry.dev-dependencies] 17 | pytest = "^5.4.1" 18 | flake8 = "^3.7.9" 19 | jupyterlab = "^2.0.1" 20 | bumpversion = "^0.5.3" 21 | bump2version = "^1.0.0" 22 | pytest-mock = "^2.0.0" 23 | pytest-cov = "^2.8.1" 24 | coverage-badge = "^1.0.1" 25 | 26 | [build-system] 27 | requires = ["poetry>=0.12"] 28 | build-backend = "poetry.masonry.api" 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | description = "Bounding box augmentation for PyTorch with Google's Brain Team augmentation policies" 4 | 5 | with open('README.md', encoding='utf-8') as f: 6 | long_description = '\n' + f.read() 7 | 8 | setup( 9 | name='bbaug', 10 | version='0.4.2', 11 | author='Harpal Sahota', 12 | author_email='harpal28sahota@gmail.com', 13 | description=description, 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | url='https://github.com/harpalsahota/bbaug', 17 | packages=[ 18 | 'bbaug', 19 | 'bbaug.augmentations', 20 | 'bbaug.policies', 21 | 'bbaug.visuals', 22 | ], 23 | python_requires='>=3.5', 24 | install_requires=[ 25 | 'imgaug', 26 | ], 27 | classifiers=[ 28 | 'Development Status 
:: 4 - Beta', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Operating System :: OS Independent', 31 | 'Programming Language :: Python', 32 | 'Programming Language :: Python :: 3', 33 | 'Programming Language :: Python :: 3.5', 34 | 'Programming Language :: Python :: 3.6', 35 | 'Programming Language :: Python :: 3.7', 36 | 'Programming Language :: Python :: 3.8', 37 | 'Programming Language :: Python :: 3 :: Only', 38 | 'Topic :: Scientific/Engineering :: Image Recognition', 39 | ] 40 | ) -------------------------------------------------------------------------------- /tests/test_augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from bbaug.augmentations import augmentations 5 | from bbaug.exceptions import InvalidMagnitude 6 | 7 | 8 | class TestNegate: 9 | 10 | def test_negate_positive(self, mocker): 11 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 12 | numpy_random_mock.return_value = 0.5 13 | 14 | @augmentations.negate 15 | def dumb_return(): 16 | return 1 17 | assert dumb_return() > 0 18 | 19 | numpy_random_mock.return_value = 0.5 20 | assert dumb_return() > 0 21 | 22 | def test_negate_negative(self, mocker): 23 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 24 | numpy_random_mock.return_value = 0.49 25 | 26 | @augmentations.negate 27 | def dumb_return(): 28 | return 1 29 | 30 | assert dumb_return() < 0 31 | 32 | 33 | def test_validate_magnitude(): 34 | 35 | @augmentations.validate_magnitude 36 | def dumb_func(magnitude): 37 | return magnitude 38 | 39 | with pytest.raises(InvalidMagnitude): 40 | dumb_func(11) 41 | 42 | with pytest.raises(InvalidMagnitude): 43 | dumb_func(-0.1) 44 | 45 | dumb_func(0) 46 | dumb_func(5) 47 | dumb_func(10) 48 | 49 | 50 | def test__img_enhance_to_arg(): 51 | 52 | res = augmentations._img_enhance_to_arg(2) 53 | assert res == pytest.approx(0.46) 54 | 
55 | res = augmentations._img_enhance_to_arg(8) 56 | assert res == pytest.approx(1.54) 57 | 58 | 59 | def test__rotate_mag_to_arg(mocker): 60 | 61 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 62 | numpy_random_mock.return_value = 0.5 63 | 64 | res = augmentations._rotate_mag_to_arg(8) 65 | assert res == pytest.approx(24.0) 66 | 67 | res = augmentations._rotate_mag_to_arg(10) 68 | assert res == pytest.approx(30.0) 69 | 70 | 71 | def test__shear_mag_to_arg(mocker): 72 | 73 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 74 | numpy_random_mock.return_value = 0.5 75 | 76 | res = augmentations._shear_mag_to_arg(0) 77 | assert res == pytest.approx(0.0) 78 | 79 | res = augmentations._shear_mag_to_arg(7) 80 | assert res == pytest.approx(0.21) 81 | 82 | 83 | def test__translate_mag_to_arg(mocker): 84 | 85 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 86 | numpy_random_mock.return_value = 0.5 87 | 88 | res = augmentations._translate_mag_to_arg(8) 89 | assert res == 200 90 | 91 | res = augmentations._translate_mag_to_arg(2) 92 | assert res == 50 93 | 94 | 95 | def test_auto_contrast(mocker): 96 | 97 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.pillike.Autocontrast') 98 | 99 | augmentations.auto_contrast(10) 100 | assert aug_mock.called 101 | aug_mock.assert_called_with(0) 102 | 103 | 104 | def test_brightness(mocker): 105 | 106 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.pillike.EnhanceBrightness') 107 | 108 | augmentations.brightness(10) 109 | assert aug_mock.called 110 | aug_mock.assert_called_with(pytest.approx(1.9)) 111 | 112 | aug_mock.reset_mock() 113 | augmentations.brightness(3) 114 | aug_mock.assert_called_with(pytest.approx(0.64)) 115 | 116 | 117 | def test_colour(mocker): 118 | 119 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.pillike.EnhanceColor') 120 | 121 | 
augmentations.colour(6) 122 | assert aug_mock.called 123 | aug_mock.assert_called_with(pytest.approx(1.18)) 124 | 125 | aug_mock.reset_mock() 126 | augmentations.colour(7) 127 | aug_mock.assert_called_with(pytest.approx(1.36)) 128 | 129 | 130 | def test_contrast(mocker): 131 | 132 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.GammaContrast') 133 | 134 | augmentations.contrast(1) 135 | assert aug_mock.called 136 | aug_mock.assert_called_with(pytest.approx(0.28)) 137 | 138 | aug_mock.reset_mock() 139 | augmentations.contrast(0) 140 | aug_mock.assert_called_with(pytest.approx(0.1)) 141 | 142 | 143 | def test_cutout(mocker): 144 | 145 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.Cutout') 146 | 147 | augmentations.cutout(2) 148 | assert aug_mock.called 149 | aug_mock.assert_called_with() 150 | 151 | aug_mock.reset_mock() 152 | augmentations.cutout(5, height=1000, width=1000) 153 | args, kwargs = aug_mock.call_args_list[0] 154 | assert {'size': pytest.approx((0.1, 0.1))} == kwargs 155 | 156 | aug_mock.reset_mock() 157 | augmentations.cutout(10, height=10, width=10) 158 | args, kwargs = aug_mock.call_args_list[0] 159 | assert {'size': pytest.approx((1.0, 1.0))} == kwargs 160 | 161 | 162 | def test_cutout_bbox(mocker): 163 | 164 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.BlendAlphaBoundingBoxes') 165 | aug_cutout_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.Cutout') 166 | 167 | augmentations.cutout_bbox(10) 168 | args, kwargs = aug_mock.call_args_list[0] 169 | assert tuple([None]) == args 170 | assert 'foreground' in kwargs 171 | aug_cutout_mock.assert_called_with() 172 | 173 | aug_mock.reset_mock() 174 | aug_cutout_mock.reset_mock() 175 | augmentations.cutout_bbox(10, height=500, width=500) 176 | args, kwargs = aug_mock.call_args_list[0] 177 | assert tuple([None]) == args 178 | assert 'foreground' in kwargs 179 | args, kwargs = aug_cutout_mock.call_args_list[0] 180 | assert tuple() == args 181 | 
assert 'size' in kwargs 182 | assert {'size': pytest.approx((0.1, 0.1))} == kwargs 183 | aug_cutout_mock.assert_called_with(size=pytest.approx((0.1, 0.1))) 184 | 185 | aug_mock.reset_mock() 186 | aug_cutout_mock.reset_mock() 187 | augmentations.cutout_bbox(10, height=250, width=5) 188 | args, kwargs = aug_mock.call_args_list[0] 189 | assert tuple([None]) == args 190 | assert 'foreground' in kwargs 191 | args, kwargs = aug_cutout_mock.call_args_list[0] 192 | assert tuple() == args 193 | assert 'size' in kwargs 194 | assert {'size': pytest.approx((0.2, 1.0))} == kwargs 195 | aug_cutout_mock.assert_called_with(size=pytest.approx((0.2, 1.0))) 196 | 197 | 198 | def test_cutout_fraction(mocker): 199 | 200 | aug_mock_cutout = mocker.patch('bbaug.augmentations.augmentations.iaa.Cutout') 201 | 202 | augmentations.cutout_fraction(7) 203 | assert aug_mock_cutout.called 204 | aug_mock_cutout.assert_called_with() 205 | 206 | aug_mock_cutout.reset_mock() 207 | augmentations.cutout_fraction(9, height=1000, width=1000) 208 | aug_mock_cutout.assert_called_with() 209 | 210 | aug_mock_cutout.reset_mock() 211 | augmentations.cutout_fraction(10, height=1000, width=1000, height_bbox=100, width_bbox=100) 212 | args, kwargs = aug_mock_cutout.call_args_list[0] 213 | assert tuple() == args 214 | assert 'size' in kwargs 215 | assert {'size': pytest.approx((0.075, 0.075))} == kwargs 216 | 217 | 218 | def test_equalise(mocker): 219 | 220 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.AllChannelsHistogramEqualization') 221 | 222 | augmentations.equalise(1) 223 | assert aug_mock.called 224 | aug_mock.assert_called_with() 225 | 226 | 227 | def test_fliplr(mocker): 228 | 229 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.BlendAlphaBoundingBoxes') 230 | aug_fliplr_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.Fliplr') 231 | 232 | augmentations.fliplr_boxes(1) 233 | args, kwargs = aug_mock.call_args_list[0] 234 | assert tuple([None]) == args 235 | 
assert 'foreground' in kwargs 236 | aug_fliplr_mock.assert_called_with(pytest.approx(1.0)) 237 | 238 | 239 | def test_posterize(mocker): 240 | 241 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.color.Posterize') 242 | 243 | augmentations.posterize(1) 244 | assert aug_mock.called 245 | args, kwargs = aug_mock.call_args_list[0] 246 | assert {'nb_bits': 1} == kwargs 247 | assert tuple() == args 248 | 249 | aug_mock.reset_mock() 250 | augmentations.posterize(9) 251 | args, kwargs = aug_mock.call_args_list[0] 252 | assert {'nb_bits': 3} == kwargs 253 | assert tuple() == args 254 | 255 | 256 | def test_rotate(mocker): 257 | 258 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 259 | numpy_random_mock.return_value = 0.5 260 | aug_mock_rotate = mocker.patch('bbaug.augmentations.augmentations.iaa.Rotate') 261 | 262 | augmentations.rotate(8) 263 | assert aug_mock_rotate.called 264 | aug_mock_rotate.assert_called_with(pytest.approx(24.0)) 265 | 266 | aug_mock_rotate.reset_mock() 267 | augmentations.rotate(10) 268 | aug_mock_rotate.assert_called_with(pytest.approx(30.0)) 269 | 270 | 271 | def test_sharpness(mocker): 272 | 273 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.pillike.EnhanceSharpness') 274 | 275 | augmentations.sharpness(4) 276 | assert aug_mock.called 277 | aug_mock.assert_called_with(pytest.approx(0.82)) 278 | 279 | 280 | def test_shear_x(mocker): 281 | 282 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 283 | numpy_random_mock.return_value = 0.49 284 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.ShearX') 285 | 286 | augmentations.shear_x(1) 287 | assert aug_mock.called 288 | aug_mock.assert_called_with(pytest.approx(-0.03)) 289 | 290 | 291 | def test_shear_x_bbox(mocker): 292 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 293 | numpy_random_mock.return_value = 0.0 294 | aug_mock = 
mocker.patch('bbaug.augmentations.augmentations.iaa.BlendAlphaBoundingBoxes') 295 | aug_shear_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.ShearX') 296 | 297 | augmentations.shear_x_bbox(7) 298 | args, kwargs = aug_mock.call_args_list[0] 299 | assert tuple([None]) == args 300 | assert 'foreground' in kwargs 301 | aug_shear_mock.assert_called_with(pytest.approx(-0.21)) 302 | 303 | 304 | def test_shear_y(mocker): 305 | 306 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 307 | numpy_random_mock.return_value = 0.32 308 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.ShearY') 309 | 310 | augmentations.shear_y(3) 311 | assert aug_mock.called 312 | aug_mock.assert_called_with(pytest.approx(-0.09)) 313 | 314 | 315 | def test_shear_y_bbox(mocker): 316 | 317 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 318 | numpy_random_mock.return_value = 0.5 319 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.ShearY') 320 | 321 | augmentations.shear_y_bbox(2) 322 | assert aug_mock.called 323 | aug_mock.assert_called_with(pytest.approx(0.06)) 324 | 325 | 326 | def test_solarize(mocker): 327 | 328 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.pillike.Solarize') 329 | 330 | augmentations.solarize(50) 331 | assert aug_mock.called 332 | aug_mock.assert_called_with(threshold=128) 333 | 334 | 335 | def test_solarize_add(): 336 | 337 | aug = augmentations.solarize_add(2) 338 | img_aug, bbs_aug = aug(np.zeros((2, 2)).astype('uint8'), []) 339 | exp = np.array([22]*4).reshape(2, 2) 340 | assert np.array_equal(exp, img_aug) 341 | 342 | aug = augmentations.solarize_add(5) 343 | img_aug, bbs_aug = aug(np.array([[22, 255], [0, 128]]).astype('uint8'), []) 344 | exp = np.array([[77, 255], [55, 128]]) 345 | assert np.array_equal(exp, img_aug) 346 | 347 | 348 | def test_translate_x(mocker): 349 | 350 | numpy_random_mock = 
mocker.patch('bbaug.augmentations.augmentations.np.random.random') 351 | numpy_random_mock.return_value = 0.72 352 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.geometric.TranslateX') 353 | 354 | augmentations.translate_x(8) 355 | assert aug_mock.called 356 | aug_mock.assert_called_with(px=200) 357 | 358 | 359 | def test_translate_x_bbox(mocker): 360 | 361 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 362 | numpy_random_mock.return_value = 0.5 363 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.BlendAlphaBoundingBoxes') 364 | aug_mock_translate = mocker.patch('bbaug.augmentations.augmentations.iaa.geometric.TranslateX') 365 | 366 | augmentations.translate_x_bbox(10) 367 | assert aug_mock.called 368 | args, kwargs = aug_mock.call_args_list[0] 369 | assert tuple([None]) == args 370 | assert 'foreground' in kwargs 371 | args, kwargs = aug_mock_translate.call_args_list[0] 372 | assert tuple() == args 373 | assert {'px': 120} == kwargs 374 | 375 | aug_mock.reset_mock() 376 | aug_mock_translate.reset_mock() 377 | augmentations.translate_x_bbox(3) 378 | assert aug_mock.called 379 | args, kwargs = aug_mock.call_args_list[0] 380 | assert tuple([None]) == args 381 | assert 'foreground' in kwargs 382 | args, kwargs = aug_mock_translate.call_args_list[0] 383 | assert tuple() == args 384 | assert {'px': 36} == kwargs 385 | 386 | 387 | def test_translate_y(mocker): 388 | 389 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 390 | numpy_random_mock.return_value = 0.22 391 | aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.geometric.TranslateY') 392 | 393 | augmentations.translate_y(7) 394 | assert aug_mock.called 395 | aug_mock.assert_called_with(px=-175) 396 | 397 | 398 | def test_translate_y_bbox(mocker): 399 | 400 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 401 | numpy_random_mock.return_value = 0.5 402 
| aug_mock = mocker.patch('bbaug.augmentations.augmentations.iaa.BlendAlphaBoundingBoxes') 403 | aug_mock_translate = mocker.patch('bbaug.augmentations.augmentations.iaa.geometric.TranslateY') 404 | 405 | augmentations.translate_y_bbox(0) 406 | assert aug_mock.called 407 | args, kwargs = aug_mock.call_args_list[0] 408 | assert tuple([None]) == args 409 | assert 'foreground' in kwargs 410 | args, kwargs = aug_mock_translate.call_args_list[0] 411 | assert tuple() == args 412 | assert {'px': 0} == kwargs 413 | 414 | aug_mock.reset_mock() 415 | aug_mock_translate.reset_mock() 416 | augmentations.translate_y_bbox(1) 417 | assert aug_mock.called 418 | args, kwargs = aug_mock.call_args_list[0] 419 | assert tuple([None]) == args 420 | assert 'foreground' in kwargs 421 | args, kwargs = aug_mock_translate.call_args_list[0] 422 | assert tuple() == args 423 | assert {'px': 12} == kwargs 424 | -------------------------------------------------------------------------------- /tests/test_policies.py: -------------------------------------------------------------------------------- 1 | from imgaug.augmentables.bbs import ( 2 | BoundingBox, 3 | BoundingBoxesOnImage, 4 | ) 5 | import numpy as np 6 | import pytest 7 | 8 | from bbaug.policies import policies 9 | 10 | 11 | def test_list_policies(): 12 | res = policies.list_policies() 13 | assert len(res) == 4 14 | assert 'policies_v0' in res 15 | assert 'policies_v1' in res 16 | assert 'policies_v2' in res 17 | assert 'policies_v3' in res 18 | 19 | 20 | def test_policies_v0(): 21 | 22 | v0_policies = policies.policies_v0() 23 | assert len(v0_policies) == 5 24 | 25 | for policy in v0_policies: 26 | assert len(policy) == 2 27 | for sub_policy in policy: 28 | assert isinstance(sub_policy, policies.POLICY_TUPLE) 29 | assert type(sub_policy.name) is str 30 | assert type(sub_policy.probability) is float 31 | assert type(sub_policy.magnitude) is int 32 | assert sub_policy.name in policies.NAME_TO_AUGMENTATION 33 | assert 0.0 <= 
sub_policy.probability <= 1.0 34 | assert 0 <= sub_policy.magnitude <= 10 35 | 36 | 37 | def test_policies_v1(): 38 | 39 | v1_policies = policies.policies_v1() 40 | assert len(v1_policies) == 20 41 | 42 | for policy in v1_policies: 43 | assert len(policy) == 2 44 | for sub_policy in policy: 45 | assert isinstance(sub_policy, policies.POLICY_TUPLE) 46 | assert type(sub_policy.name) is str 47 | assert type(sub_policy.probability) is float 48 | assert type(sub_policy.magnitude) is int 49 | assert sub_policy.name in policies.NAME_TO_AUGMENTATION 50 | assert 0.0 <= sub_policy.probability <= 1.0 51 | assert 0 <= sub_policy.magnitude <= 10 52 | 53 | 54 | def test_policies_v2(): 55 | 56 | v2_policies = policies.policies_v2() 57 | assert len(v2_policies) == 15 58 | 59 | for policy in v2_policies: 60 | assert (len(policy) == 2 or len(policy) == 3) 61 | for sub_policy in policy: 62 | assert isinstance(sub_policy, policies.POLICY_TUPLE) 63 | assert type(sub_policy.name) is str 64 | assert type(sub_policy.probability) is float 65 | assert type(sub_policy.magnitude) is int 66 | assert sub_policy.name in policies.NAME_TO_AUGMENTATION 67 | assert 0.0 <= sub_policy.probability <= 1.0 68 | assert 0 <= sub_policy.magnitude <= 10 69 | 70 | 71 | def test_policies_v3(): 72 | 73 | v3_policies = policies.policies_v3() 74 | assert len(v3_policies) == 15 75 | 76 | for policy in v3_policies: 77 | assert len(policy) == 2 78 | for sub_policy in policy: 79 | assert isinstance(sub_policy, policies.POLICY_TUPLE) 80 | assert type(sub_policy.name) is str 81 | assert type(sub_policy.probability) is float 82 | assert type(sub_policy.magnitude) is int 83 | assert sub_policy.name in policies.NAME_TO_AUGMENTATION 84 | assert 0.0 <= sub_policy.probability <= 1.0 85 | assert 0 <= sub_policy.magnitude <= 10 86 | 87 | 88 | class TestPolicyContainer: 89 | 90 | def test__init__(self, mocker): 91 | random_mocker = mocker.patch('bbaug.policies.policies.random.seed') 92 | numpy_random_mocker = 
mocker.patch('bbaug.policies.policies.np.random.seed') 93 | 94 | p = policies.PolicyContainer(policies.policies_v3()) 95 | assert not random_mocker.called 96 | assert not numpy_random_mocker.called 97 | 98 | p = policies.PolicyContainer(policies.policies_v3(), random_state=42) 99 | random_mocker.assert_called_with(42) 100 | numpy_random_mocker.assert_called_with(42) 101 | 102 | def test___get__item(self): 103 | p = policies.PolicyContainer(policies.policies_v3()) 104 | assert p['Color'].__name__ == 'colour' 105 | 106 | def test__bbs_to_percent(self): 107 | p = policies.PolicyContainer(policies.policies_v3()) 108 | bbs = BoundingBoxesOnImage( 109 | [BoundingBox(*bb, label=label) for bb, label in zip([[0, 0, 25, 25]], [0])], 110 | (100, 100) 111 | ) 112 | res = p._bbs_to_percent(bbs, 100, 100) 113 | assert np.allclose( 114 | np.array([[0, 0.125, 0.125, 0.25, 0.25]]), 115 | res 116 | ) 117 | 118 | def test__bbs_to_pixek(self): 119 | p = policies.PolicyContainer(policies.policies_v3()) 120 | bbs = BoundingBoxesOnImage( 121 | [BoundingBox(*bb, label=label) for bb, label in zip([[0, 0, 25, 25]], [1])], 122 | (100, 100) 123 | ) 124 | res = p._bbs_to_pixel(bbs) 125 | assert np.array_equal( 126 | np.array([[1, 0, 0, 25, 25]]), 127 | res 128 | ) 129 | 130 | def test_select_random_policy(self): 131 | p = policies.PolicyContainer(policies.policies_v3()) 132 | random_policy = p.select_random_policy() 133 | assert random_policy in p.policies 134 | 135 | def test_apply_augmentation(self, mocker): 136 | numpy_random_mock = mocker.patch('bbaug.augmentations.augmentations.np.random.random') 137 | numpy_random_mock.return_value = 0.0 138 | bbs_to_percent_mock = mocker.patch('bbaug.policies.policies.PolicyContainer._bbs_to_percent') 139 | bbs_to_pixel_mock = mocker.patch('bbaug.policies.policies.PolicyContainer._bbs_to_pixel') 140 | 141 | def aug_mock(image, bounding_boxes): 142 | return image, bounding_boxes 143 | 144 | bbcutout_mock = 
mocker.patch('bbaug.augmentations.augmentations.cutout_bbox') 145 | bbcutout_mock.return_value = aug_mock 146 | p = policies.PolicyContainer( 147 | policies.policies_v3(), 148 | name_to_augmentation={'Cutout_BBox': bbcutout_mock}, 149 | ) 150 | policy = [policies.POLICY_TUPLE('Cutout_BBox', 0.2, 10)] 151 | bbs = [[0, 0, 25, 25]] 152 | p.apply_augmentation(policy, np.zeros((100, 100, 3)).astype('uint8'), bbs, [0]) 153 | assert bbcutout_mock.called 154 | bbcutout_mock.assert_called_with(10, height=100, width=100) 155 | assert not bbs_to_percent_mock.called 156 | assert bbs_to_pixel_mock.called 157 | 158 | bbs_to_percent_mock.reset_mock() 159 | bbs_to_pixel_mock.reset_mock() 160 | numpy_random_mock.return_value = 1.0 161 | bbcutout_mock.reset_mock() 162 | p = policies.PolicyContainer( 163 | policies.policies_v3(), 164 | name_to_augmentation={'Cutout_BBox': bbcutout_mock}, 165 | ) 166 | policy = [policies.POLICY_TUPLE('Cutout_BBox', 0.2, 10)] 167 | bbs = [[0, 0, 25, 25]] 168 | p.apply_augmentation(policy, np.zeros((100, 100, 3)).astype('uint8'), bbs, [0]) 169 | assert not bbcutout_mock.called 170 | assert not bbs_to_percent_mock.called 171 | assert bbs_to_pixel_mock.called 172 | 173 | bbs_to_percent_mock.reset_mock() 174 | bbs_to_pixel_mock.reset_mock() 175 | numpy_random_mock.return_value = 0.0 176 | bbcutout_mock.reset_mock() 177 | p = policies.PolicyContainer( 178 | policies.policies_v3(), 179 | name_to_augmentation={'BBox': bbcutout_mock}, 180 | ) 181 | policy = [policies.POLICY_TUPLE('BBox', 0.2, 10)] 182 | bbs = [[0, 0, 25, 25]] 183 | p.apply_augmentation(policy, np.zeros((100, 100, 3)).astype('uint8'), bbs, [0]) 184 | assert bbcutout_mock.called 185 | assert not bbs_to_percent_mock.called 186 | assert bbs_to_pixel_mock.called 187 | 188 | bbs_to_percent_mock.reset_mock() 189 | bbs_to_pixel_mock.reset_mock() 190 | numpy_random_mock.return_value = 0.0 191 | cutout_mock = mocker.patch('bbaug.augmentations.augmentations.cutout') 192 | cutout_mock.return_value = 
aug_mock 193 | p = policies.PolicyContainer( 194 | policies.policies_v3(), 195 | name_to_augmentation={'Cutout': cutout_mock}, 196 | ) 197 | policy = [policies.POLICY_TUPLE('Cutout', 0.2, 10)] 198 | bbs = [[0, 0, 25, 25]] 199 | p.apply_augmentation(policy, np.zeros((250, 300, 3)).astype('uint8'), bbs, [0]) 200 | assert cutout_mock.called 201 | assert not bbs_to_percent_mock.called 202 | assert bbs_to_pixel_mock.called 203 | cutout_mock.assert_called_with(10, height=250, width=300) 204 | 205 | numpy_random_mock.return_value = 0.0 206 | bbs_to_percent_mock.reset_mock() 207 | bbs_to_pixel_mock.reset_mock() 208 | colour_mock = mocker.patch('bbaug.augmentations.augmentations.colour') 209 | colour_mock.return_value = aug_mock 210 | p = policies.PolicyContainer( 211 | policies.policies_v3(), 212 | name_to_augmentation={'Color': colour_mock}, 213 | return_yolo=True 214 | ) 215 | policy = [policies.POLICY_TUPLE('Color', 0.2, 10)] 216 | p.apply_augmentation(policy, np.zeros((100, 100, 3)).astype('uint8'), bbs, [1]) 217 | assert colour_mock.called 218 | assert bbs_to_percent_mock.called 219 | assert not bbs_to_pixel_mock.called 220 | 221 | 222 | cutout_fraction_mock = mocker.patch('bbaug.augmentations.augmentations.cutout_fraction') 223 | cutout_fraction_mock.return_value = aug_mock 224 | p = policies.PolicyContainer( 225 | policies.policies_v3(), 226 | name_to_augmentation={'Cutout_Fraction': cutout_fraction_mock}, 227 | ) 228 | policy = [policies.POLICY_TUPLE('Cutout_Fraction', 0.2, 10)] 229 | p.apply_augmentation(policy, np.zeros((100, 100, 3)).astype('uint8'), [], []) 230 | assert cutout_fraction_mock.called 231 | cutout_fraction_mock.assert_called_with(10) 232 | 233 | cutout_fraction_mock.reset_mock() 234 | cutout_fraction_mock.return_value = aug_mock 235 | p = policies.PolicyContainer( 236 | policies.policies_v3(), 237 | name_to_augmentation={'Cutout_Fraction': cutout_fraction_mock}, 238 | ) 239 | policy = [policies.POLICY_TUPLE('Cutout_Fraction', 0.2, 10)] 240 | bbs = 
[[0, 0, 25, 25]] 241 | p.apply_augmentation(policy, np.zeros((100, 100, 3)).astype('uint8'), bbs, [5]) 242 | assert cutout_fraction_mock.called 243 | cutout_fraction_mock.assert_called_with(10, height=100, width=100, height_bbox=25, width_bbox=25) 244 | -------------------------------------------------------------------------------- /tests/test_visuals.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import call 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from bbaug.policies import ( 7 | POLICY_TUPLE, 8 | PolicyContainer, 9 | policies_v3 10 | ) 11 | from bbaug.visuals import visualise_policy 12 | 13 | def test_visualise_policy(mocker): 14 | mock_plt = mocker.patch('bbaug.visuals.visuals.plt') 15 | colour_mock = mocker.patch('bbaug.augmentations.augmentations.colour') 16 | imgio_mock = mocker.patch('bbaug.visuals.visuals.imageio') 17 | policy_container_mock = mocker.patch('bbaug.visuals.visuals.PolicyContainer') 18 | 19 | def aug_mock(image, bounding_boxes): 20 | return image, bounding_boxes 21 | 22 | class MockShape: 23 | 24 | @property 25 | def shape(self): 26 | return 100, 100, 3 27 | 28 | axes_mock = mocker.MagicMock() 29 | fig_mock = mocker.MagicMock() 30 | mock_plt.subplots.return_value = (fig_mock, axes_mock) 31 | colour_mock.return_value = aug_mock 32 | imgio_mock.imread.return_value = MockShape() 33 | policy_container_mock().apply_augmentation.return_value = ( 34 | np.zeros((100, 100, 3)).astype('uint8'), 35 | np.array([[9, 0, 50, 25, 75]]) 36 | ) 37 | visualise_policy( 38 | './test/image/dir.png', 39 | './test/save/dir', 40 | [[0, 50, 25, 75]], 41 | [9], 42 | [[POLICY_TUPLE('Color', 0.2, 10)]], 43 | name_to_augmentation={'Color': colour_mock} 44 | ) 45 | 46 | mock_plt.subplots.assert_called_with(1, 3, figsize=(15, 4)) 47 | imgio_mock.imread.assert_called_with('./test/image/dir.png') 48 | fig_mock.suptitle.assert_called_with([POLICY_TUPLE('Color', pytest.approx(0.2), 10)]) 49 | assert 
fig_mock.savefig.call_args_list == [call('./test/save/dir/sub_policy_0.png')] 50 | --------------------------------------------------------------------------------