├── .deepsource.toml ├── .github └── workflows │ ├── pypi.yml │ └── tests.yml ├── .gitignore ├── Dockerfile.dev ├── LICENSE ├── MANIFEST.in ├── README.md ├── __init__.py ├── requirements.txt ├── setup.py ├── tests ├── test_base.py └── test_transforms.py └── ttach ├── __init__.py ├── __version__.py ├── aliases.py ├── base.py ├── functional.py ├── transforms.py └── wrappers.py /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | [[analyzers]] 4 | name = "python" 5 | enabled = true 6 | runtime_version = "3.x.x" 7 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.6' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | test: 14 | 15 | runs-on: ubuntu-18.04 
16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.6 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.6 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install codecov pytest 27 | pip install . 28 | pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 29 | - name: Test 30 | run: | 31 | python -m pytest -s tests 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ -------------------------------------------------------------------------------- /Dockerfile.dev: -------------------------------------------------------------------------------- 1 | FROM anibali/pytorch:no-cuda 2 | 3 | # install requirements 4 | RUN pip install pytest 5 | 6 | # copy project 7 | COPY . /project 8 | WORKDIR /project 9 | 10 | # install project 11 | RUN pip install . 
12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019, Pavel Yakubovskiy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TTAch 2 | Image Test Time Augmentation with PyTorch! 3 | 4 | Similar to what Data Augmentation is doing to the training set, the purpose of Test Time Augmentation is to perform random modifications to the test images. 
Thus, instead of showing the regular, “clean” images, only once to the trained model, we will show it the augmented images several times. We will then average the predictions of each corresponding image and take that as our final guess [[1](https://towardsdatascience.com/test-time-augmentation-tta-and-how-to-perform-it-with-keras-4ac19b67fb4d)]. 5 | ``` 6 | Input 7 | | # input batch of images 8 | / / /|\ \ \ # apply augmentations (flips, rotation, scale, etc.) 9 | | | | | | | | # pass augmented batches through model 10 | | | | | | | | # reverse transformations for each batch of masks/labels 11 | \ \ \ / / / # merge predictions (mean, max, gmean, etc.) 12 | | # output batch of masks/labels 13 | Output 14 | ``` 15 | ## Table of Contents 16 | 1. [Quick Start](#quick-start) 17 | 2. [Transforms](#transforms) 18 | 3. [Aliases](#aliases) 19 | 4. [Merge modes](#merge-modes) 20 | 5. [Installation](#installation) 21 | 22 | ## Quick start 23 | 24 | ##### Segmentation model wrapping [[docstring](ttach/wrappers.py#L8)]: 25 | ```python 26 | import ttach as tta 27 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean') 28 | ``` 29 | ##### Classification model wrapping [[docstring](ttach/wrappers.py#L52)]: 30 | ```python 31 | tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform()) 32 | ``` 33 | 34 | ##### Keypoints model wrapping [[docstring](ttach/wrappers.py#L96)]: 35 | ```python 36 | tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True) 37 | ``` 38 | **Note**: the model must return keypoints in the format `torch([x1, y1, ..., xn, yn])` 39 | 40 | ## Advanced Examples 41 | ##### Custom transform: 42 | ```python 43 | # defined 2 * 2 * 3 * 3 = 36 augmentations ! 
44 | transforms = tta.Compose( 45 | [ 46 | tta.HorizontalFlip(), 47 | tta.Rotate90(angles=[0, 180]), 48 | tta.Scale(scales=[1, 2, 4]), 49 | tta.Multiply(factors=[0.9, 1, 1.1]), 50 | ] 51 | ) 52 | 53 | tta_model = tta.SegmentationTTAWrapper(model, transforms) 54 | ``` 55 | ##### Custom model (multi-input / multi-output) 56 | ```python 57 | # Example how to process ONE batch on images with TTA 58 | # Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N) 59 | 60 | for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 61 | 62 | # augment image 63 | augmented_image = transformer.augment_image(image) 64 | 65 | # pass to model 66 | model_output = model(augmented_image, another_input_data) 67 | 68 | # reverse augmentation for mask and label 69 | deaug_mask = transformer.deaugment_mask(model_output['mask']) 70 | deaug_label = transformer.deaugment_label(model_output['label']) 71 | 72 | # save results 73 | masks.append(deaug_mask) 74 | labels.append(deaug_label) 75 | 76 | # reduce results as you want, e.g. mean/max/min 77 | label = mean(labels) 78 | mask = mean(masks) 79 | ``` 80 | 81 | ## Transforms 82 | 83 | | Transform | Parameters | Values | 84 | |----------------|:-------------------------:|:---------------------------------:| 85 | | HorizontalFlip | - | - | 86 | | VerticalFlip | - | - | 87 | | Rotate90 | angles | List\[0, 90, 180, 270] | 88 | | Scale | scales
interpolation | List\[float]
"nearest"/"linear"| 89 | | Resize | sizes
original_size
interpolation | List\[Tuple\[int, int]]
Tuple\[int,int]
"nearest"/"linear"| 90 | | Add | values | List\[float] | 91 | | Multiply | factors | List\[float] | 92 | | FiveCrops | crop_height
crop_width | int
int | 93 | 94 | ## Aliases 95 | 96 | - flip_transform (horizontal + vertical flips) 97 | - hflip_transform (horizontal flip) 98 | - d4_transform (flips + rotation 0, 90, 180, 270) 99 | - multiscale_transform (scale transform, take scales as input parameter) 100 | - five_crop_transform (corner crops + center crop) 101 | - ten_crop_transform (five crops + five crops on horizontal flip) 102 | 103 | ## Merge modes 104 | - mean 105 | - gmean (geometric mean) 106 | - sum 107 | - max 108 | - min 109 | - tsharpen ([temperature sharpen](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/107716#latest-624046) with t=0.5) 110 | 111 | ## Installation 112 | PyPI: 113 | ```bash 114 | $ pip install ttach 115 | ``` 116 | Source: 117 | ```bash 118 | $ pip install git+https://github.com/qubvel/ttach 119 | ``` 120 | 121 | ## Run tests 122 | 123 | ```bash 124 | docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider 125 | ``` 126 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .ttach import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/ttach/94e579e59a21cbdfbb4f5790502e648008ecf64e/requirements.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pip install twine 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 
15 | NAME = 'ttach' 16 | DESCRIPTION = 'Images test time augmentation with PyTorch.' 17 | URL = 'https://github.com/qubvel/ttach' 18 | EMAIL = 'qubvel@gmail.com' 19 | AUTHOR = 'Pavel Yakubovskiy' 20 | REQUIRES_PYTHON = '>=3.0.0' 21 | VERSION = None 22 | 23 | # The rest you shouldn't have to touch too much :) 24 | # ------------------------------------------------ 25 | # Except, perhaps the License and Trove Classifiers! 26 | # If you do change the License, remember to change the Trove Classifier for that! 27 | 28 | here = os.path.abspath(os.path.dirname(__file__)) 29 | 30 | # What packages are required for this module to be executed? 31 | try: 32 | with open(os.path.join(here, 'requirements.txt'), encoding='utf-8') as f: 33 | REQUIRED = f.read().split('\n') 34 | except: 35 | REQUIRED = [] 36 | 37 | # What packages are optional? 38 | EXTRAS = { 39 | 'test': ['pytest'] 40 | } 41 | 42 | # Import the README and use it as the long-description. 43 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 44 | try: 45 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 46 | long_description = '\n' + f.read() 47 | except FileNotFoundError: 48 | long_description = DESCRIPTION 49 | 50 | # Load the package's __version__.py module as a dictionary. 51 | about = {} 52 | if not VERSION: 53 | with open(os.path.join(here, NAME, '__version__.py')) as f: 54 | exec(f.read(), about) 55 | else: 56 | about['__version__'] = VERSION 57 | 58 | 59 | class UploadCommand(Command): 60 | """Support setup.py upload.""" 61 | 62 | description = 'Build and publish the package.' 
63 | user_options = [] 64 | 65 | @staticmethod 66 | def status(s): 67 | """Prints things in bold.""" 68 | print(s) 69 | 70 | def initialize_options(self): 71 | pass 72 | 73 | def finalize_options(self): 74 | pass 75 | 76 | def run(self): 77 | try: 78 | self.status('Removing previous builds...') 79 | rmtree(os.path.join(here, 'dist')) 80 | except OSError: 81 | pass 82 | 83 | self.status('Building Source and Wheel (universal) distribution...') 84 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 85 | 86 | self.status('Uploading the package to PyPI via Twine...') 87 | os.system('twine upload dist/*') 88 | 89 | self.status('Pushing git tags...') 90 | os.system('git tag v{0}'.format(about['__version__'])) 91 | os.system('git push --tags') 92 | 93 | sys.exit() 94 | 95 | 96 | # Where the magic happens: 97 | setup( 98 | name=NAME, 99 | version=about['__version__'], 100 | description=DESCRIPTION, 101 | long_description=long_description, 102 | long_description_content_type='text/markdown', 103 | author=AUTHOR, 104 | author_email=EMAIL, 105 | python_requires=REQUIRES_PYTHON, 106 | url=URL, 107 | packages=find_packages(exclude=('tests', 'docs', 'images')), 108 | # If your package is a single module, use this instead of 'packages': 109 | # py_modules=['mypackage'], 110 | 111 | # entry_points={ 112 | # 'console_scripts': ['mycli=mymodule:cli'], 113 | # }, 114 | install_requires=REQUIRED, 115 | extras_require=EXTRAS, 116 | include_package_data=True, 117 | license='MIT', 118 | classifiers=[ 119 | # Trove classifiers 120 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 121 | 'License :: OSI Approved :: MIT License', 122 | 'Programming Language :: Python', 123 | 'Programming Language :: Python :: 3', 124 | 'Programming Language :: Python :: Implementation :: CPython', 125 | 'Programming Language :: Python :: Implementation :: PyPy' 126 | ], 127 | # $ setup.py publish support. 
128 | cmdclass={ 129 | 'upload': UploadCommand, 130 | }, 131 | ) 132 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import ttach as tta 4 | 5 | 6 | def test_compose_1(): 7 | transform = tta.Compose( 8 | [ 9 | tta.HorizontalFlip(), 10 | tta.VerticalFlip(), 11 | tta.Rotate90(angles=[0, 90, 180, 270]), 12 | tta.Scale(scales=[1, 2, 4], interpolation="nearest"), 13 | ] 14 | ) 15 | 16 | assert len(transform) == 2 * 2 * 4 * 3 # all combinations for aug parameters 17 | 18 | dummy_label = torch.ones(2).reshape(2, 1).float() 19 | dummy_image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float() 20 | dummy_model = lambda x: {"label": dummy_label, "mask": x} 21 | 22 | for augmenter in transform: 23 | augmented_image = augmenter.augment_image(dummy_image) 24 | model_output = dummy_model(augmented_image) 25 | deaugmented_mask = augmenter.deaugment_mask(model_output["mask"]) 26 | deaugmented_label = augmenter.deaugment_label(model_output["label"]) 27 | assert torch.allclose(deaugmented_mask, dummy_image) 28 | assert torch.allclose(deaugmented_label, dummy_label) 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "case", 33 | [ 34 | ("mean", 0.5), 35 | ("gmean", 0.0), 36 | ("max", 1.0), 37 | ("min", 0.0), 38 | ("sum", 1.5), 39 | ("tsharpen", 0.56903558), 40 | ], 41 | ) 42 | def test_merger(case): 43 | merge_type, output = case 44 | input = [1.0, 0.0, 0.5] 45 | merger = tta.base.Merger(type=merge_type, n=len(input)) 46 | for i in input: 47 | merger.append(torch.tensor(i)) 48 | assert torch.allclose(merger.result, torch.tensor(output)) 49 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import ttach as tta 4 | 5 | 6 | 
@pytest.mark.parametrize( 7 | "transform", 8 | [ 9 | tta.HorizontalFlip(), 10 | tta.VerticalFlip(), 11 | tta.Rotate90(angles=[0, 90, 180, 270]), 12 | tta.Scale(scales=[1, 2, 4], interpolation="nearest"), 13 | tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest") 14 | ], 15 | ) 16 | def test_aug_deaug_mask(transform): 17 | a = torch.arange(20).reshape(1, 1, 4, 5).float() 18 | for p in transform.params: 19 | aug = transform.apply_aug_image(a, **{transform.pname: p}) 20 | deaug = transform.apply_deaug_mask(aug, **{transform.pname: p}) 21 | assert torch.allclose(a, deaug) 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "transform", 26 | [ 27 | tta.HorizontalFlip(), 28 | tta.VerticalFlip(), 29 | tta.Rotate90(angles=[0, 90, 180, 270]), 30 | tta.Scale(scales=[1, 2, 4], interpolation="nearest"), 31 | tta.Add(values=[-1, 0, 1, 2]), 32 | tta.Multiply(factors=[-1, 0, 1, 2]), 33 | tta.FiveCrops(crop_height=3, crop_width=5), 34 | tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest") 35 | ], 36 | ) 37 | def test_label_is_same(transform): 38 | a = torch.arange(20).reshape(1, 1, 4, 5).float() 39 | for p in transform.params: 40 | aug = transform.apply_aug_image(a, **{transform.pname: p}) 41 | deaug = transform.apply_deaug_label(aug, **{transform.pname: p}) 42 | assert torch.allclose(aug, deaug) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "transform", 47 | [ 48 | tta.HorizontalFlip(), 49 | tta.VerticalFlip() 50 | ], 51 | ) 52 | def test_flip_keypoints(transform): 53 | keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]]) 54 | for p in transform.params: 55 | aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p}) 56 | deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: p}) 57 | assert torch.allclose(keypoints, deaug) 58 | 59 | 60 | @pytest.mark.parametrize( 61 | "transform", 62 | [ 63 | tta.Rotate90(angles=[0, 90, 180, 270]) 64 | ], 65 | ) 66 | def 
test_rotate90_keypoints(transform): 67 | keypoints = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]]) 68 | for p in transform.params: 69 | aug = transform.apply_deaug_keypoints(keypoints.detach().clone(), **{transform.pname: p}) 70 | deaug = transform.apply_deaug_keypoints(aug, **{transform.pname: -p}) 71 | assert torch.allclose(keypoints, deaug) 72 | 73 | 74 | def test_add_transform(): 75 | transform = tta.Add(values=[-1, 0, 1]) 76 | a = torch.arange(20).reshape(1, 1, 4, 5).float() 77 | for p in transform.params: 78 | aug = transform.apply_aug_image(a, **{transform.pname: p}) 79 | assert torch.allclose(aug, a + p) 80 | 81 | 82 | def test_multiply_transform(): 83 | transform = tta.Multiply(factors=[-1, 0, 1]) 84 | a = torch.arange(20).reshape(1, 1, 4, 5).float() 85 | for p in transform.params: 86 | aug = transform.apply_aug_image(a, **{transform.pname: p}) 87 | assert torch.allclose(aug, a * p) 88 | 89 | 90 | def test_fivecrop_transform(): 91 | transform = tta.FiveCrops(crop_height=1, crop_width=1) 92 | a = torch.arange(25).reshape(1, 1, 5, 5).float() 93 | output = [0, 20, 24, 4, 12] 94 | for i, p in enumerate(transform.params): 95 | aug = transform.apply_aug_image(a, **{transform.pname: p}) 96 | assert aug.item() == output[i] 97 | 98 | # 99 | # def test_resize_transform(): 100 | # transform = tta.Resize(sizes=[(10, 10), (5, 5)], original_size=(5, 5)) 101 | # a = torch.arange(25).reshape(1, 1, 5, 5).float() 102 | # for i, p in enumerate(transform.params): 103 | # aug = transform.apply_aug_image(a, **{transform.pname: p}) 104 | # assert aug.item() == output[i] -------------------------------------------------------------------------------- /ttach/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrappers import ( 2 | SegmentationTTAWrapper, 3 | ClassificationTTAWrapper, 4 | KeypointsTTAWrapper 5 | ) 6 | from .base import Compose 7 | 8 | from .transforms import ( 9 | HorizontalFlip, 
VerticalFlip, Rotate90, Scale, Add, Multiply, FiveCrops, Resize 10 | ) 11 | 12 | from . import aliases 13 | 14 | from .__version__ import __version__ 15 | -------------------------------------------------------------------------------- /ttach/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 0, 3) 2 | 3 | __version__ = '.'.join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /ttach/aliases.py: -------------------------------------------------------------------------------- 1 | from .base import Compose 2 | from . import transforms as tta 3 | 4 | 5 | def flip_transform(): 6 | return Compose([tta.HorizontalFlip(), tta.VerticalFlip()]) 7 | 8 | 9 | def hflip_transform(): 10 | return Compose([tta.HorizontalFlip()]) 11 | 12 | 13 | def vflip_transform(): 14 | return Compose([tta.VerticalFlip()]) 15 | 16 | 17 | def d4_transform(): 18 | return Compose( 19 | [ 20 | tta.HorizontalFlip(), 21 | tta.Rotate90(angles=[0, 90, 180, 270]), 22 | ] 23 | ) 24 | 25 | def multiscale_transform(scales, interpolation="nearest"): 26 | return Compose([tta.Scale(scales, interpolation=interpolation)]) 27 | 28 | 29 | def five_crop_transform(crop_height, crop_width): 30 | return Compose([tta.FiveCrops(crop_height, crop_width)]) 31 | 32 | 33 | def ten_crop_transform(crop_height, crop_width): 34 | return Compose([tta.HorizontalFlip(), tta.FiveCrops(crop_height, crop_width)]) 35 | -------------------------------------------------------------------------------- /ttach/base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from functools import partial 3 | from typing import List, Optional, Union 4 | 5 | from . 
import functional as F 6 | 7 | 8 | class BaseTransform: 9 | identity_param = None 10 | 11 | def __init__( 12 | self, 13 | name: str, 14 | params: Union[list, tuple], 15 | ): 16 | self.params = params 17 | self.pname = name 18 | 19 | def apply_aug_image(self, image, *args, **params): 20 | raise NotImplementedError 21 | 22 | def apply_deaug_mask(self, mask, *args, **params): 23 | raise NotImplementedError 24 | 25 | def apply_deaug_label(self, label, *args, **params): 26 | raise NotImplementedError 27 | 28 | def apply_deaug_keypoints(self, keypoints, *args, **params): 29 | raise NotImplementedError 30 | 31 | 32 | class ImageOnlyTransform(BaseTransform): 33 | 34 | def apply_deaug_mask(self, mask, *args, **params): 35 | return mask 36 | 37 | def apply_deaug_label(self, label, *args, **params): 38 | return label 39 | 40 | def apply_deaug_keypoints(self, keypoints, *args, **params): 41 | return keypoints 42 | 43 | 44 | class DualTransform(BaseTransform): 45 | pass 46 | 47 | 48 | class Chain: 49 | 50 | def __init__( 51 | self, 52 | functions: List[callable] 53 | ): 54 | self.functions = functions or [] 55 | 56 | def __call__(self, x): 57 | for f in self.functions: 58 | x = f(x) 59 | return x 60 | 61 | 62 | class Transformer: 63 | def __init__( 64 | self, 65 | image_pipeline: Chain, 66 | mask_pipeline: Chain, 67 | label_pipeline: Chain, 68 | keypoints_pipeline: Chain 69 | ): 70 | self.image_pipeline = image_pipeline 71 | self.mask_pipeline = mask_pipeline 72 | self.label_pipeline = label_pipeline 73 | self.keypoints_pipeline = keypoints_pipeline 74 | 75 | def augment_image(self, image): 76 | return self.image_pipeline(image) 77 | 78 | def deaugment_mask(self, mask): 79 | return self.mask_pipeline(mask) 80 | 81 | def deaugment_label(self, label): 82 | return self.label_pipeline(label) 83 | 84 | def deaugment_keypoints(self, keypoints): 85 | return self.keypoints_pipeline(keypoints) 86 | 87 | 88 | class Compose: 89 | 90 | def __init__( 91 | self, 92 | transforms: 
List[BaseTransform], 93 | ): 94 | self.aug_transforms = transforms 95 | self.aug_transform_parameters = list(itertools.product(*[t.params for t in self.aug_transforms])) 96 | self.deaug_transforms = transforms[::-1] 97 | self.deaug_transform_parameters = [p[::-1] for p in self.aug_transform_parameters] 98 | 99 | def __iter__(self) -> Transformer: 100 | for aug_params, deaug_params in zip(self.aug_transform_parameters, self.deaug_transform_parameters): 101 | image_aug_chain = Chain([partial(t.apply_aug_image, **{t.pname: p}) 102 | for t, p in zip(self.aug_transforms, aug_params)]) 103 | mask_deaug_chain = Chain([partial(t.apply_deaug_mask, **{t.pname: p}) 104 | for t, p in zip(self.deaug_transforms, deaug_params)]) 105 | label_deaug_chain = Chain([partial(t.apply_deaug_label, **{t.pname: p}) 106 | for t, p in zip(self.deaug_transforms, deaug_params)]) 107 | keypoints_deaug_chain = Chain([partial(t.apply_deaug_keypoints, **{t.pname: p}) 108 | for t, p in zip(self.deaug_transforms, deaug_params)]) 109 | yield Transformer( 110 | image_pipeline=image_aug_chain, 111 | mask_pipeline=mask_deaug_chain, 112 | label_pipeline=label_deaug_chain, 113 | keypoints_pipeline=keypoints_deaug_chain 114 | ) 115 | 116 | def __len__(self) -> int: 117 | return len(self.aug_transform_parameters) 118 | 119 | 120 | class Merger: 121 | 122 | def __init__( 123 | self, 124 | type: str = 'mean', 125 | n: int = 1, 126 | ): 127 | 128 | if type not in ['mean', 'gmean', 'sum', 'max', 'min', 'tsharpen']: 129 | raise ValueError('Not correct merge type `{}`.'.format(type)) 130 | 131 | self.output = None 132 | self.type = type 133 | self.n = n 134 | 135 | def append(self, x): 136 | 137 | if self.type == 'tsharpen': 138 | x = x ** 0.5 139 | 140 | if self.output is None: 141 | self.output = x 142 | elif self.type in ['mean', 'sum', 'tsharpen']: 143 | self.output = self.output + x 144 | elif self.type == 'gmean': 145 | self.output = self.output * x 146 | elif self.type == 'max': 147 | self.output = 
F.max(self.output, x) 148 | elif self.type == 'min': 149 | self.output = F.min(self.output, x) 150 | 151 | @property 152 | def result(self): 153 | if self.type in ['sum', 'max', 'min']: 154 | result = self.output 155 | elif self.type in ['mean', 'tsharpen']: 156 | result = self.output / self.n 157 | elif self.type in ['gmean']: 158 | result = self.output ** (1 / self.n) 159 | else: 160 | raise ValueError('Not correct merge type `{}`.'.format(self.type)) 161 | return result 162 | -------------------------------------------------------------------------------- /ttach/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def rot90(x, k=1): 6 | """rotate batch of images by 90 degrees k times""" 7 | return torch.rot90(x, k, (2, 3)) 8 | 9 | 10 | def hflip(x): 11 | """flip batch of images horizontally""" 12 | return x.flip(3) 13 | 14 | 15 | def vflip(x): 16 | """flip batch of images vertically""" 17 | return x.flip(2) 18 | 19 | 20 | def sum(x1, x2): 21 | """sum of two tensors""" 22 | return x1 + x2 23 | 24 | 25 | def add(x, value): 26 | """add value to tensor""" 27 | return x + value 28 | 29 | 30 | def max(x1, x2): 31 | """compare 2 tensors and take max values""" 32 | return torch.max(x1, x2) 33 | 34 | 35 | def min(x1, x2): 36 | """compare 2 tensors and take min values""" 37 | return torch.min(x1, x2) 38 | 39 | 40 | def multiply(x, factor): 41 | """multiply tensor by factor""" 42 | return x * factor 43 | 44 | 45 | def scale(x, scale_factor, interpolation="nearest", align_corners=None): 46 | """scale batch of images by `scale_factor` with given interpolation mode""" 47 | h, w = x.shape[2:] 48 | new_h = int(h * scale_factor) 49 | new_w = int(w * scale_factor) 50 | return F.interpolate( 51 | x, size=(new_h, new_w), mode=interpolation, align_corners=align_corners 52 | ) 53 | 54 | 55 | def resize(x, size, interpolation="nearest", align_corners=None): 56 | """resize batch 
of images to given spatial size with given interpolation mode""" 57 | return F.interpolate(x, size=size, mode=interpolation, align_corners=align_corners) 58 | 59 | 60 | def crop(x, x_min=None, x_max=None, y_min=None, y_max=None): 61 | """perform crop on batch of images""" 62 | return x[:, :, y_min:y_max, x_min:x_max] 63 | 64 | 65 | def crop_lt(x, crop_h, crop_w): 66 | """crop left top corner""" 67 | return x[:, :, 0:crop_h, 0:crop_w] 68 | 69 | 70 | def crop_lb(x, crop_h, crop_w): 71 | """crop left bottom corner""" 72 | return x[:, :, -crop_h:, 0:crop_w] 73 | 74 | 75 | def crop_rt(x, crop_h, crop_w): 76 | """crop right top corner""" 77 | return x[:, :, 0:crop_h, -crop_w:] 78 | 79 | 80 | def crop_rb(x, crop_h, crop_w): 81 | """crop right bottom corner""" 82 | return x[:, :, -crop_h:, -crop_w:] 83 | 84 | 85 | def center_crop(x, crop_h, crop_w): 86 | """make center crop""" 87 | 88 | center_h = x.shape[2] // 2 89 | center_w = x.shape[3] // 2 90 | half_crop_h = crop_h // 2 91 | half_crop_w = crop_w // 2 92 | 93 | y_min = center_h - half_crop_h 94 | y_max = center_h + half_crop_h + crop_h % 2 95 | x_min = center_w - half_crop_w 96 | x_max = center_w + half_crop_w + crop_w % 2 97 | 98 | return x[:, :, y_min:y_max, x_min:x_max] 99 | 100 | 101 | def _disassemble_keypoints(keypoints): 102 | x = keypoints[..., 0] 103 | y = keypoints[..., 1] 104 | return x, y 105 | 106 | 107 | def _assemble_keypoints(x, y): 108 | return torch.stack([x, y], dim=-1) 109 | 110 | 111 | def keypoints_hflip(keypoints): 112 | x, y = _disassemble_keypoints(keypoints) 113 | return _assemble_keypoints(1. - x, y) 114 | 115 | 116 | def keypoints_vflip(keypoints): 117 | x, y = _disassemble_keypoints(keypoints) 118 | return _assemble_keypoints(x, 1. 
- y) 119 | 120 | 121 | def keypoints_rot90(keypoints, k=1): 122 | 123 | if k not in {0, 1, 2, 3}: 124 | raise ValueError("Parameter k must be in [0:3]") 125 | if k == 0: 126 | return keypoints 127 | x, y = _disassemble_keypoints(keypoints) 128 | 129 | if k == 1: 130 | xy = [y, 1. - x] 131 | elif k == 2: 132 | xy = [1. - x, 1. - y] 133 | elif k == 3: 134 | xy = [1. - y, x] 135 | 136 | return _assemble_keypoints(*xy) 137 | -------------------------------------------------------------------------------- /ttach/transforms.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, List, Union, Tuple 3 | from . import functional as F 4 | from .base import DualTransform, ImageOnlyTransform 5 | 6 | 7 | class HorizontalFlip(DualTransform): 8 | """Flip images horizontally (left->right)""" 9 | 10 | identity_param = False 11 | 12 | def __init__(self): 13 | super().__init__("apply", [False, True]) 14 | 15 | def apply_aug_image(self, image, apply=False, **kwargs): 16 | if apply: 17 | image = F.hflip(image) 18 | return image 19 | 20 | def apply_deaug_mask(self, mask, apply=False, **kwargs): 21 | if apply: 22 | mask = F.hflip(mask) 23 | return mask 24 | 25 | def apply_deaug_label(self, label, apply=False, **kwargs): 26 | return label 27 | 28 | def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs): 29 | if apply: 30 | keypoints = F.keypoints_hflip(keypoints) 31 | return keypoints 32 | 33 | 34 | class VerticalFlip(DualTransform): 35 | """Flip images vertically (up->down)""" 36 | 37 | identity_param = False 38 | 39 | def __init__(self): 40 | super().__init__("apply", [False, True]) 41 | 42 | def apply_aug_image(self, image, apply=False, **kwargs): 43 | if apply: 44 | image = F.vflip(image) 45 | return image 46 | 47 | def apply_deaug_mask(self, mask, apply=False, **kwargs): 48 | if apply: 49 | mask = F.vflip(mask) 50 | return mask 51 | 52 | def apply_deaug_label(self, label, 
apply=False, **kwargs): 53 | return label 54 | 55 | def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs): 56 | if apply: 57 | keypoints = F.keypoints_vflip(keypoints) 58 | return keypoints 59 | 60 | 61 | class Rotate90(DualTransform): 62 | """Rotate images 0/90/180/270 degrees 63 | 64 | Args: 65 | angles (list): angles to rotate images 66 | """ 67 | 68 | identity_param = 0 69 | 70 | def __init__(self, angles: List[int]): 71 | if self.identity_param not in angles: 72 | angles = [self.identity_param] + list(angles) 73 | 74 | super().__init__("angle", angles) 75 | 76 | def apply_aug_image(self, image, angle=0, **kwargs): 77 | k = angle // 90 if angle >= 0 else (angle + 360) // 90 78 | return F.rot90(image, k) 79 | 80 | def apply_deaug_mask(self, mask, angle=0, **kwargs): 81 | return self.apply_aug_image(mask, -angle) 82 | 83 | def apply_deaug_label(self, label, angle=0, **kwargs): 84 | return label 85 | 86 | def apply_deaug_keypoints(self, keypoints, angle=0, **kwargs): 87 | angle *= -1 88 | k = angle // 90 if angle >= 0 else (angle + 360) // 90 89 | return F.keypoints_rot90(keypoints, k=k) 90 | 91 | 92 | class Scale(DualTransform): 93 | """Scale images 94 | 95 | Args: 96 | scales (List[Union[int, float]]): scale factors for spatial image dimensions 97 | interpolation (str): one of "nearest"/"lenear" (see more in torch.nn.interpolate) 98 | align_corners (bool): see more in torch.nn.interpolate 99 | """ 100 | 101 | identity_param = 1 102 | 103 | def __init__( 104 | self, 105 | scales: List[Union[int, float]], 106 | interpolation: str = "nearest", 107 | align_corners: Optional[bool] = None, 108 | ): 109 | if self.identity_param not in scales: 110 | scales = [self.identity_param] + list(scales) 111 | self.interpolation = interpolation 112 | self.align_corners = align_corners 113 | 114 | super().__init__("scale", scales) 115 | 116 | def apply_aug_image(self, image, scale=1, **kwargs): 117 | if scale != self.identity_param: 118 | image = F.scale( 119 | image, 
120 | scale, 121 | interpolation=self.interpolation, 122 | align_corners=self.align_corners, 123 | ) 124 | return image 125 | 126 | def apply_deaug_mask(self, mask, scale=1, **kwargs): 127 | if scale != self.identity_param: 128 | mask = F.scale( 129 | mask, 130 | 1 / scale, 131 | interpolation=self.interpolation, 132 | align_corners=self.align_corners, 133 | ) 134 | return mask 135 | 136 | def apply_deaug_label(self, label, scale=1, **kwargs): 137 | return label 138 | 139 | def apply_deaug_keypoints(self, keypoints, scale=1, **kwargs): 140 | return keypoints 141 | 142 | 143 | class Resize(DualTransform): 144 | """Resize images 145 | 146 | Args: 147 | sizes (List[Tuple[int, int]): scale factors for spatial image dimensions 148 | original_size Tuple(int, int): optional, image original size for deaugmenting mask 149 | interpolation (str): one of "nearest"/"lenear" (see more in torch.nn.interpolate) 150 | align_corners (bool): see more in torch.nn.interpolate 151 | """ 152 | 153 | def __init__( 154 | self, 155 | sizes: List[Tuple[int, int]], 156 | original_size: Tuple[int, int] = None, 157 | interpolation: str = "nearest", 158 | align_corners: Optional[bool] = None, 159 | ): 160 | if original_size is not None and original_size not in sizes: 161 | sizes = [original_size] + list(sizes) 162 | self.interpolation = interpolation 163 | self.align_corners = align_corners 164 | self.original_size = original_size 165 | 166 | super().__init__("size", sizes) 167 | 168 | def apply_aug_image(self, image, size, **kwargs): 169 | if size != self.original_size: 170 | image = F.resize( 171 | image, 172 | size, 173 | interpolation=self.interpolation, 174 | align_corners=self.align_corners, 175 | ) 176 | return image 177 | 178 | def apply_deaug_mask(self, mask, size, **kwargs): 179 | if self.original_size is None: 180 | raise ValueError( 181 | "Provide original image size to make mask backward transformation" 182 | ) 183 | if size != self.original_size: 184 | mask = F.resize( 185 | mask, 
186 | self.original_size, 187 | interpolation=self.interpolation, 188 | align_corners=self.align_corners, 189 | ) 190 | return mask 191 | 192 | def apply_deaug_label(self, label, size=1, **kwargs): 193 | return label 194 | 195 | def apply_deaug_keypoints(self, keypoints, size=1, **kwargs): 196 | return keypoints 197 | 198 | 199 | class Add(ImageOnlyTransform): 200 | """Add value to images 201 | 202 | Args: 203 | values (List[float]): values to add to each pixel 204 | """ 205 | 206 | identity_param = 0 207 | 208 | def __init__(self, values: List[float]): 209 | 210 | if self.identity_param not in values: 211 | values = [self.identity_param] + list(values) 212 | super().__init__("value", values) 213 | 214 | def apply_aug_image(self, image, value=0, **kwargs): 215 | if value != self.identity_param: 216 | image = F.add(image, value) 217 | return image 218 | 219 | 220 | class Multiply(ImageOnlyTransform): 221 | """Multiply images by factor 222 | 223 | Args: 224 | factors (List[float]): factor to multiply each pixel by 225 | """ 226 | 227 | identity_param = 1 228 | 229 | def __init__(self, factors: List[float]): 230 | if self.identity_param not in factors: 231 | factors = [self.identity_param] + list(factors) 232 | super().__init__("factor", factors) 233 | 234 | def apply_aug_image(self, image, factor=1, **kwargs): 235 | if factor != self.identity_param: 236 | image = F.multiply(image, factor) 237 | return image 238 | 239 | 240 | class FiveCrops(ImageOnlyTransform): 241 | """Makes 4 crops for each corner + center crop 242 | 243 | Args: 244 | crop_height (int): crop height in pixels 245 | crop_width (int): crop width in pixels 246 | """ 247 | 248 | def __init__(self, crop_height, crop_width): 249 | crop_functions = ( 250 | partial(F.crop_lt, crop_h=crop_height, crop_w=crop_width), 251 | partial(F.crop_lb, crop_h=crop_height, crop_w=crop_width), 252 | partial(F.crop_rb, crop_h=crop_height, crop_w=crop_width), 253 | partial(F.crop_rt, crop_h=crop_height, crop_w=crop_width), 
254 | partial(F.center_crop, crop_h=crop_height, crop_w=crop_width), 255 | ) 256 | super().__init__("crop_fn", crop_functions) 257 | 258 | def apply_aug_image(self, image, crop_fn=None, **kwargs): 259 | return crop_fn(image) 260 | 261 | def apply_deaug_mask(self, mask, **kwargs): 262 | raise ValueError("`FiveCrop` augmentation is not suitable for mask!") 263 | 264 | def apply_deaug_keypoints(self, keypoints, **kwargs): 265 | raise ValueError("`FiveCrop` augmentation is not suitable for keypoints!") 266 | -------------------------------------------------------------------------------- /ttach/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional, Mapping, Union, Tuple 4 | 5 | from .base import Merger, Compose 6 | 7 | 8 | class SegmentationTTAWrapper(nn.Module): 9 | """Wrap PyTorch nn.Module (segmentation model) with test time augmentation transforms 10 | 11 | Args: 12 | model (torch.nn.Module): segmentation model with single input and single output 13 | (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor]) 14 | transforms (ttach.Compose): composition of test time transforms 15 | merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen 16 | output_mask_key (str): if model output is `dict`, specify which key belong to `mask` 17 | """ 18 | 19 | def __init__( 20 | self, 21 | model: nn.Module, 22 | transforms: Compose, 23 | merge_mode: str = "mean", 24 | output_mask_key: Optional[str] = None, 25 | ): 26 | super().__init__() 27 | self.model = model 28 | self.transforms = transforms 29 | self.merge_mode = merge_mode 30 | self.output_key = output_mask_key 31 | 32 | def forward( 33 | self, image: torch.Tensor, *args 34 | ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]: 35 | merger = Merger(type=self.merge_mode, n=len(self.transforms)) 36 | 37 | for transformer in self.transforms: 38 | augmented_image = 
transformer.augment_image(image) 39 | augmented_output = self.model(augmented_image, *args) 40 | if self.output_key is not None: 41 | augmented_output = augmented_output[self.output_key] 42 | deaugmented_output = transformer.deaugment_mask(augmented_output) 43 | merger.append(deaugmented_output) 44 | 45 | result = merger.result 46 | if self.output_key is not None: 47 | result = {self.output_key: result} 48 | 49 | return result 50 | 51 | 52 | class ClassificationTTAWrapper(nn.Module): 53 | """Wrap PyTorch nn.Module (classification model) with test time augmentation transforms 54 | 55 | Args: 56 | model (torch.nn.Module): classification model with single input and single output 57 | (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor]) 58 | transforms (ttach.Compose): composition of test time transforms 59 | merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen 60 | output_label_key (str): if model output is `dict`, specify which key belong to `label` 61 | """ 62 | 63 | def __init__( 64 | self, 65 | model: nn.Module, 66 | transforms: Compose, 67 | merge_mode: str = "mean", 68 | output_label_key: Optional[str] = None, 69 | ): 70 | super().__init__() 71 | self.model = model 72 | self.transforms = transforms 73 | self.merge_mode = merge_mode 74 | self.output_key = output_label_key 75 | 76 | def forward( 77 | self, image: torch.Tensor, *args 78 | ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]: 79 | merger = Merger(type=self.merge_mode, n=len(self.transforms)) 80 | 81 | for transformer in self.transforms: 82 | augmented_image = transformer.augment_image(image) 83 | augmented_output = self.model(augmented_image, *args) 84 | if self.output_key is not None: 85 | augmented_output = augmented_output[self.output_key] 86 | deaugmented_output = transformer.deaugment_label(augmented_output) 87 | merger.append(deaugmented_output) 88 | 89 | result = merger.result 90 | if self.output_key is not None: 91 | result = 
{self.output_key: result} 92 | 93 | return result 94 | 95 | 96 | class KeypointsTTAWrapper(nn.Module): 97 | """Wrap PyTorch nn.Module (keypoints model) with test time augmentation transforms 98 | 99 | Args: 100 | model (torch.nn.Module): keypoints model with single input and single output 101 | in format [x1,y1, x2, y2, ..., xn, yn] 102 | (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor]) 103 | transforms (ttach.Compose): composition of test time transforms 104 | merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen 105 | output_keypoints_key (str): if model output is `dict`, specify which key belong to `label` 106 | scaled (bool): True if model return x, y scaled values in [0, 1], else False 107 | 108 | """ 109 | 110 | def __init__( 111 | self, 112 | model: nn.Module, 113 | transforms: Compose, 114 | merge_mode: str = "mean", 115 | output_keypoints_key: Optional[str] = None, 116 | scaled: bool = False, 117 | ): 118 | super().__init__() 119 | self.model = model 120 | self.transforms = transforms 121 | self.merge_mode = merge_mode 122 | self.output_key = output_keypoints_key 123 | self.scaled = scaled 124 | 125 | def forward( 126 | self, image: torch.Tensor, *args 127 | ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]: 128 | merger = Merger(type=self.merge_mode, n=len(self.transforms)) 129 | size = image.size() 130 | batch_size, image_height, image_width = size[0], size[2], size[3] 131 | 132 | for transformer in self.transforms: 133 | augmented_image = transformer.augment_image(image) 134 | augmented_output = self.model(augmented_image, *args) 135 | 136 | if self.output_key is not None: 137 | augmented_output = augmented_output[self.output_key] 138 | 139 | augmented_output = augmented_output.reshape(batch_size, -1, 2) 140 | if not self.scaled: 141 | augmented_output[..., 0] /= image_width 142 | augmented_output[..., 1] /= image_height 143 | 144 | deaugmented_output = 
transformer.deaugment_keypoints(augmented_output) 145 | merger.append(deaugmented_output) 146 | 147 | result = merger.result 148 | 149 | if not self.scaled: 150 | result[..., 0] *= image_width 151 | result[..., 1] *= image_height 152 | result = result.reshape(batch_size, -1) 153 | 154 | if self.output_key is not None: 155 | result = {self.output_key: result} 156 | 157 | return result 158 | --------------------------------------------------------------------------------