├── .deepsource.toml
├── .github
└── workflows
│ ├── pypi.yml
│ └── tests.yml
├── .gitignore
├── Dockerfile.dev
├── LICENSE
├── MANIFEST.in
├── README.md
├── __init__.py
├── requirements.txt
├── setup.py
├── tests
├── test_base.py
└── test_transforms.py
└── ttach
├── __init__.py
├── __version__.py
├── aliases.py
├── base.py
├── functional.py
├── transforms.py
└── wrappers.py
/.deepsource.toml:
--------------------------------------------------------------------------------
1 | version = 1
2 |
3 | [[analyzers]]
4 | name = "python"
5 | enabled = true
6 | runtime_version = "3.x.x"
7 |
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python
13 | uses: actions/setup-python@v2
14 | with:
15 | python-version: '3.6'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Build and publish
21 | env:
22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
24 | run: |
25 | python setup.py sdist bdist_wheel
26 | twine upload dist/*
27 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | test:
14 |
15 | runs-on: ubuntu-18.04
16 |
17 | steps:
18 | - uses: actions/checkout@v2
 19 |     - name: Set up Python 3.6
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: 3.6
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | python -m pip install codecov pytest
27 | pip install .
28 | pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
29 | - name: Test
30 | run: |
31 | python -m pytest -s tests
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
--------------------------------------------------------------------------------
/Dockerfile.dev:
--------------------------------------------------------------------------------
1 | FROM anibali/pytorch:no-cuda
2 |
3 | # install requirements
4 | RUN pip install pytest
5 |
6 | # copy project
7 | COPY . /project
8 | WORKDIR /project
9 |
10 | # install project
11 | RUN pip install .
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2019, Pavel Yakubovskiy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md LICENSE requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TTAch
2 | Image Test Time Augmentation with PyTorch!
3 |
4 | Similar to what Data Augmentation is doing to the training set, the purpose of Test Time Augmentation is to perform random modifications to the test images. Thus, instead of showing the regular, “clean” images, only once to the trained model, we will show it the augmented images several times. We will then average the predictions of each corresponding image and take that as our final guess [[1](https://towardsdatascience.com/test-time-augmentation-tta-and-how-to-perform-it-with-keras-4ac19b67fb4d)].
5 | ```
6 | Input
7 | | # input batch of images
8 | / / /|\ \ \ # apply augmentations (flips, rotation, scale, etc.)
9 | | | | | | | | # pass augmented batches through model
10 | | | | | | | | # reverse transformations for each batch of masks/labels
11 | \ \ \ / / / # merge predictions (mean, max, gmean, etc.)
12 | | # output batch of masks/labels
13 | Output
14 | ```
15 | ## Table of Contents
16 | 1. [Quick Start](#quick-start)
17 | 2. [Transforms](#transforms)
18 | 3. [Aliases](#aliases)
19 | 4. [Merge modes](#merge-modes)
20 | 5. [Installation](#installation)
21 |
22 | ## Quick start
23 |
24 | ##### Segmentation model wrapping [[docstring](ttach/wrappers.py#L8)]:
25 | ```python
26 | import ttach as tta
27 | tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
28 | ```
29 | ##### Classification model wrapping [[docstring](ttach/wrappers.py#L52)]:
30 | ```python
31 | tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
32 | ```
33 |
34 | ##### Keypoints model wrapping [[docstring](ttach/wrappers.py#L96)]:
35 | ```python
36 | tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
37 | ```
 38 | **Note**: the model must return keypoints in the format `torch.tensor([x1, y1, ..., xn, yn])`
39 |
40 | ## Advanced Examples
41 | ##### Custom transform:
42 | ```python
43 | # defined 2 * 2 * 3 * 3 = 36 augmentations !
44 | transforms = tta.Compose(
45 | [
46 | tta.HorizontalFlip(),
47 | tta.Rotate90(angles=[0, 180]),
48 | tta.Scale(scales=[1, 2, 4]),
49 | tta.Multiply(factors=[0.9, 1, 1.1]),
50 | ]
51 | )
52 |
53 | tta_model = tta.SegmentationTTAWrapper(model, transforms)
54 | ```
55 | ##### Custom model (multi-input / multi-output)
56 | ```python
57 | # Example how to process ONE batch on images with TTA
58 | # Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
59 |
60 | for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform()
61 |
62 | # augment image
63 | augmented_image = transformer.augment_image(image)
64 |
65 | # pass to model
66 | model_output = model(augmented_image, another_input_data)
67 |
68 | # reverse augmentation for mask and label
69 | deaug_mask = transformer.deaugment_mask(model_output['mask'])
70 | deaug_label = transformer.deaugment_label(model_output['label'])
71 |
72 | # save results
 73 |     masks.append(deaug_mask)
 74 |     labels.append(deaug_label)
75 |
76 | # reduce results as you want, e.g mean/max/min
77 | label = mean(labels)
78 | mask = mean(masks)
79 | ```
80 |
81 | ## Transforms
82 |
83 | | Transform | Parameters | Values |
84 | |----------------|:-------------------------:|:---------------------------------:|
85 | | HorizontalFlip | - | - |
86 | | VerticalFlip | - | - |
87 | | Rotate90 | angles | List\[0, 90, 180, 270] |
88 | | Scale | scales
interpolation | List\[float]
"nearest"/"linear"|
89 | | Resize | sizes
original_size
interpolation | List\[Tuple\[int, int]]
Tuple\[int,int]
"nearest"/"linear"|
90 | | Add | values | List\[float] |
91 | | Multiply | factors | List\[float] |
92 | | FiveCrops | crop_height
crop_width | int
int |
93 |
94 | ## Aliases
95 |
96 | - flip_transform (horizontal + vertical flips)
97 | - hflip_transform (horizontal flip)
98 | - d4_transform (flips + rotation 0, 90, 180, 270)
99 | - multiscale_transform (scale transform, take scales as input parameter)
100 | - five_crop_transform (corner crops + center crop)
101 | - ten_crop_transform (five crops + five crops on horizontal flip)
102 |
103 | ## Merge modes
104 | - mean
105 | - gmean (geometric mean)
106 | - sum
107 | - max
108 | - min
109 | - tsharpen ([temperature sharpen](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/107716#latest-624046) with t=0.5)
110 |
111 | ## Installation
112 | PyPI:
113 | ```bash
114 | $ pip install ttach
115 | ```
116 | Source:
117 | ```bash
118 | $ pip install git+https://github.com/qubvel/ttach
119 | ```
120 |
121 | ## Run tests
122 |
123 | ```bash
124 | docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider
125 | ```
126 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .ttach import *
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/ttach/94e579e59a21cbdfbb4f5790502e648008ecf64e/requirements.txt
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Note: To use the 'upload' functionality of this file, you must:
5 | # $ pip install twine
6 |
7 | import io
8 | import os
9 | import sys
10 | from shutil import rmtree
11 |
12 | from setuptools import find_packages, setup, Command
13 |
# Package meta-data.
NAME = 'ttach'
DESCRIPTION = 'Images test time augmentation with PyTorch.'
URL = 'https://github.com/qubvel/ttach'
EMAIL = 'qubvel@gmail.com'
AUTHOR = 'Pavel Yakubovskiy'
REQUIRES_PYTHON = '>=3.0.0'
VERSION = None  # left as None so the version is read from ttach/__version__.py below

# The rest you shouldn't have to touch too much :)
# ------------------------------------------------
# Except, perhaps the License and Trove Classifiers!
# If you do change the License, remember to change the Trove Classifier for that!

here = os.path.abspath(os.path.dirname(__file__))  # directory containing this setup.py
29 |
# What packages are required for this module to be executed?
# Best-effort: fall back to no pinned requirements when requirements.txt is
# missing or unreadable (e.g. an sdist that omitted it).
try:
    with open(os.path.join(here, 'requirements.txt'), encoding='utf-8') as f:
        # Strip whitespace and drop blank lines so install_requires never
        # receives empty-string entries (split('\n') used to produce them).
        REQUIRED = [line.strip() for line in f if line.strip()]
except Exception:
    # Was a bare `except:`; Exception keeps the best-effort behaviour without
    # swallowing SystemExit/KeyboardInterrupt.
    REQUIRED = []
36 |
# What packages are optional?
EXTRAS = {
    'test': ['pytest']
}

# Import the README and use it as the long-description.
# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
try:
    with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    # No README in this checkout: fall back to the one-line description.
    long_description = DESCRIPTION

# Load the package's __version__.py module as a dictionary.
# NOTE: exec-ing the file is the conventional setup.py way to read the version
# without importing the (possibly not-yet-installed) package itself.
about = {}
if not VERSION:
    with open(os.path.join(here, NAME, '__version__.py')) as f:
        exec(f.read(), about)
else:
    about['__version__'] = VERSION
57 |
58 |
class UploadCommand(Command):
    """Support setup.py upload.

    Builds sdist + universal wheel, uploads both to PyPI via twine, then
    creates and pushes a git tag ``v<version>``.
    """

    description = 'Build and publish the package.'
    user_options = []  # no command-line options

    @staticmethod
    def status(s):
        """Print a status message."""
        print(s)

    def initialize_options(self):
        # Required by the distutils Command interface; nothing to initialize.
        pass

    def finalize_options(self):
        # Required by the distutils Command interface; nothing to finalize.
        pass

    def run(self):
        try:
            self.status('Removing previous builds...')
            rmtree(os.path.join(here, 'dist'))
        except OSError:
            # dist/ does not exist yet -- nothing to clean.
            pass

        self.status('Building Source and Wheel (universal) distribution...')
        os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable))

        self.status('Uploading the package to PyPI via Twine...')
        os.system('twine upload dist/*')

        self.status('Pushing git tags...')
        os.system('git tag v{0}'.format(about['__version__']))
        os.system('git push --tags')

        sys.exit()
94 |
95 |
# Where the magic happens:
setup(
    name=NAME,
    version=about['__version__'],
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    author_email=EMAIL,
    python_requires=REQUIRES_PYTHON,
    url=URL,
    # Ship only the ttach package; tests/docs/images stay out of the wheel.
    packages=find_packages(exclude=('tests', 'docs', 'images')),
    # If your package is a single module, use this instead of 'packages':
    # py_modules=['mypackage'],

    # entry_points={
    #     'console_scripts': ['mycli=mymodule:cli'],
    # },
    install_requires=REQUIRED,
    extras_require=EXTRAS,
    include_package_data=True,
    license='MIT',
    classifiers=[
        # Trove classifiers
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: Implementation :: CPython',
        'Programming Language :: Python :: Implementation :: PyPy'
    ],
    # $ setup.py publish support.
    cmdclass={
        'upload': UploadCommand,
    },
)
132 |
--------------------------------------------------------------------------------
/tests/test_base.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import ttach as tta
4 |
5 |
def test_compose_1():
    """Every transformer produced by Compose must round-trip mask and label."""
    transform = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Rotate90(angles=[0, 90, 180, 270]),
            tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        ]
    )

    # 2 flips x 2 flips x 4 angles x 3 scales
    assert len(transform) == 2 * 2 * 4 * 3

    label = torch.ones(2).reshape(2, 1).float()
    image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float()

    def model(x):
        return {"label": label, "mask": x}

    for transformer in transform:
        output = model(transformer.augment_image(image))
        restored_mask = transformer.deaugment_mask(output["mask"])
        restored_label = transformer.deaugment_label(output["label"])
        assert torch.allclose(restored_mask, image)
        assert torch.allclose(restored_label, label)
29 |
30 |
@pytest.mark.parametrize(
    "case",
    [
        ("mean", 0.5),
        ("gmean", 0.0),
        ("max", 1.0),
        ("min", 0.0),
        ("sum", 1.5),
        ("tsharpen", 0.56903558),
    ],
)
def test_merger(case):
    """Each merge mode must reduce [1, 0, 0.5] to its expected scalar."""
    merge_type, expected = case
    values = [1.0, 0.0, 0.5]
    merger = tta.base.Merger(type=merge_type, n=len(values))
    for value in values:
        merger.append(torch.tensor(value))
    assert torch.allclose(merger.result, torch.tensor(expected))
49 |
--------------------------------------------------------------------------------
/tests/test_transforms.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import ttach as tta
4 |
5 |
@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest")
    ],
)
def test_aug_deaug_mask(transform):
    """Augmenting an image then de-augmenting it as a mask must round-trip."""
    image = torch.arange(20).reshape(1, 1, 4, 5).float()
    for param in transform.params:
        kwargs = {transform.pname: param}
        augmented = transform.apply_aug_image(image, **kwargs)
        restored = transform.apply_deaug_mask(augmented, **kwargs)
        assert torch.allclose(image, restored)
22 |
23 |
@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Add(values=[-1, 0, 1, 2]),
        tta.Multiply(factors=[-1, 0, 1, 2]),
        tta.FiveCrops(crop_height=3, crop_width=5),
        tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest")
    ],
)
def test_label_is_same(transform):
    """De-augmenting a label must be a no-op for every transform."""
    image = torch.arange(20).reshape(1, 1, 4, 5).float()
    for param in transform.params:
        kwargs = {transform.pname: param}
        augmented = transform.apply_aug_image(image, **kwargs)
        label_out = transform.apply_deaug_label(augmented, **kwargs)
        assert torch.allclose(augmented, label_out)
43 |
44 |
@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip()
    ],
)
def test_flip_keypoints(transform):
    """Applying a flip to keypoints twice must return the original points."""
    points = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
    for param in transform.params:
        kwargs = {transform.pname: param}
        once = transform.apply_deaug_keypoints(points.detach().clone(), **kwargs)
        twice = transform.apply_deaug_keypoints(once, **kwargs)
        assert torch.allclose(points, twice)
58 |
59 |
@pytest.mark.parametrize(
    "transform",
    [
        tta.Rotate90(angles=[0, 90, 180, 270])
    ],
)
def test_rotate90_keypoints(transform):
    """Rotating keypoints by an angle then by its negation must round-trip."""
    points = torch.tensor([[0.1, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.4, 0.3]])
    for angle in transform.params:
        rotated = transform.apply_deaug_keypoints(points.detach().clone(), **{transform.pname: angle})
        restored = transform.apply_deaug_keypoints(rotated, **{transform.pname: -angle})
        assert torch.allclose(points, restored)
72 |
73 |
def test_add_transform():
    """The Add transform must shift every pixel by the chosen value."""
    transform = tta.Add(values=[-1, 0, 1])
    image = torch.arange(20).reshape(1, 1, 4, 5).float()
    for value in transform.params:
        shifted = transform.apply_aug_image(image, **{transform.pname: value})
        assert torch.allclose(shifted, image + value)
80 |
81 |
def test_multiply_transform():
    """The Multiply transform must scale every pixel by the chosen factor."""
    transform = tta.Multiply(factors=[-1, 0, 1])
    image = torch.arange(20).reshape(1, 1, 4, 5).float()
    for factor in transform.params:
        scaled = transform.apply_aug_image(image, **{transform.pname: factor})
        assert torch.allclose(scaled, image * factor)
88 |
89 |
def test_fivecrop_transform():
    """A 1x1 five-crop of a 5x5 image picks the four corners and the center."""
    transform = tta.FiveCrops(crop_height=1, crop_width=1)
    image = torch.arange(25).reshape(1, 1, 5, 5).float()
    expected = [0, 20, 24, 4, 12]
    for param, value in zip(transform.params, expected):
        crop = transform.apply_aug_image(image, **{transform.pname: param})
        assert crop.item() == value
97 |
98 | #
99 | # def test_resize_transform():
100 | # transform = tta.Resize(sizes=[(10, 10), (5, 5)], original_size=(5, 5))
101 | # a = torch.arange(25).reshape(1, 1, 5, 5).float()
102 | # for i, p in enumerate(transform.params):
103 | # aug = transform.apply_aug_image(a, **{transform.pname: p})
104 | # assert aug.item() == output[i]
--------------------------------------------------------------------------------
/ttach/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrappers import (
2 | SegmentationTTAWrapper,
3 | ClassificationTTAWrapper,
4 | KeypointsTTAWrapper
5 | )
6 | from .base import Compose
7 |
8 | from .transforms import (
9 | HorizontalFlip, VerticalFlip, Rotate90, Scale, Add, Multiply, FiveCrops, Resize
10 | )
11 |
12 | from . import aliases
13 |
14 | from .__version__ import __version__
15 |
--------------------------------------------------------------------------------
/ttach/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (0, 0, 3)
2 |
3 | __version__ = '.'.join(map(str, VERSION))
4 |
--------------------------------------------------------------------------------
/ttach/aliases.py:
--------------------------------------------------------------------------------
1 | from .base import Compose
2 | from . import transforms as tta
3 |
4 |
def flip_transform():
    """Horizontal and vertical flips (4 augmentations total)."""
    transforms = [tta.HorizontalFlip(), tta.VerticalFlip()]
    return Compose(transforms)
7 |
8 |
def hflip_transform():
    """Horizontal flip only (2 augmentations: identity + flipped)."""
    transforms = [tta.HorizontalFlip()]
    return Compose(transforms)
11 |
12 |
def vflip_transform():
    """Vertical flip only (2 augmentations: identity + flipped)."""
    transforms = [tta.VerticalFlip()]
    return Compose(transforms)
15 |
16 |
def d4_transform():
    """D4 dihedral group: horizontal flip combined with 0/90/180/270 rotations."""
    transforms = [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
    ]
    return Compose(transforms)
24 |
def multiscale_transform(scales, interpolation="nearest"):
    """Multi-scale TTA over the given list of scale factors."""
    scale = tta.Scale(scales, interpolation=interpolation)
    return Compose([scale])
27 |
28 |
def five_crop_transform(crop_height, crop_width):
    """Four corner crops plus the center crop."""
    crops = tta.FiveCrops(crop_height, crop_width)
    return Compose([crops])
31 |
32 |
def ten_crop_transform(crop_height, crop_width):
    """Five crops of the image plus five crops of its horizontal flip."""
    transforms = [tta.HorizontalFlip(), tta.FiveCrops(crop_height, crop_width)]
    return Compose(transforms)
35 |
--------------------------------------------------------------------------------
/ttach/base.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from functools import partial
3 | from typing import List, Optional, Union
4 |
5 | from . import functional as F
6 |
7 |
class BaseTransform:
    """Common interface for TTA transforms.

    A transform is described by the name of its single parameter (``pname``)
    and the list of values (``params``) that parameter may take; subclasses
    implement the forward hook for images and the inverse hooks for each
    output type.
    """

    # Parameter value for which the transform is a no-op; subclasses override.
    identity_param = None

    def __init__(
        self,
        name: str,
        params: Union[list, tuple],
    ):
        self.pname = name
        self.params = params

    def apply_aug_image(self, image, *args, **params):
        """Apply the transform to a batch of images."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Reverse the transform on a batch of predicted masks."""
        raise NotImplementedError

    def apply_deaug_label(self, label, *args, **params):
        """Reverse the transform on predicted labels."""
        raise NotImplementedError

    def apply_deaug_keypoints(self, keypoints, *args, **params):
        """Reverse the transform on predicted keypoints."""
        raise NotImplementedError
30 |
31 |
class ImageOnlyTransform(BaseTransform):
    """Transform that alters only the image; every de-augmentation is identity."""

    def apply_deaug_mask(self, mask, *args, **params):
        # Image-only transforms never change mask geometry.
        return mask

    def apply_deaug_keypoints(self, keypoints, *args, **params):
        # Keypoint coordinates pass through unchanged.
        return keypoints

    def apply_deaug_label(self, label, *args, **params):
        # Labels are likewise unaffected.
        return label
42 |
43 |
class DualTransform(BaseTransform):
    """Transform whose image-space effect must be undone on the outputs.

    Subclasses implement both the augmentation hook and the matching
    de-augmentation hooks declared on :class:`BaseTransform`.
    """
    pass
46 |
47 |
class Chain:
    """Compose a sequence of callables into a single function.

    Calling the chain feeds the input through each function in order,
    passing each result to the next.
    """

    def __init__(
        self,
        functions: List[callable]
    ):
        # A falsy value (None / empty list) yields an identity chain.
        self.functions = functions or []

    def __call__(self, x):
        result = x
        for function in self.functions:
            result = function(result)
        return result
60 |
61 |
class Transformer:
    """Bundle of pipelines for one augmentation-parameter combination.

    ``augment_image`` runs the forward (augmentation) chain; the
    ``deaugment_*`` methods run the matching inverse chains on model outputs.
    """

    def __init__(
        self,
        image_pipeline: Chain,
        mask_pipeline: Chain,
        label_pipeline: Chain,
        keypoints_pipeline: Chain
    ):
        self.image_pipeline = image_pipeline
        self.mask_pipeline = mask_pipeline
        self.label_pipeline = label_pipeline
        self.keypoints_pipeline = keypoints_pipeline

    def augment_image(self, image):
        """Apply the forward augmentation pipeline to a batch of images."""
        return self.image_pipeline(image)

    def deaugment_mask(self, mask):
        """Apply the inverse pipeline to predicted masks."""
        return self.mask_pipeline(mask)

    def deaugment_label(self, label):
        """Apply the inverse pipeline to predicted labels."""
        return self.label_pipeline(label)

    def deaugment_keypoints(self, keypoints):
        """Apply the inverse pipeline to predicted keypoints."""
        return self.keypoints_pipeline(keypoints)
86 |
87 |
class Compose:
    """Cartesian product of several transforms' parameter values.

    Iterating yields one :class:`Transformer` per parameter combination.
    Augmentation applies transforms in the given order; de-augmentation
    applies them in reverse order with the correspondingly reversed
    parameter tuple.
    """

    def __init__(
        self,
        transforms: List[BaseTransform],
    ):
        self.aug_transforms = transforms
        self.aug_transform_parameters = list(itertools.product(*[t.params for t in self.aug_transforms]))
        self.deaug_transforms = transforms[::-1]
        self.deaug_transform_parameters = [p[::-1] for p in self.aug_transform_parameters]

    def __iter__(self) -> Transformer:
        def build_chain(method_name, transforms, params):
            # Bind each transform's hook to its parameter value for this combo.
            return Chain([
                partial(getattr(t, method_name), **{t.pname: p})
                for t, p in zip(transforms, params)
            ])

        for aug_params, deaug_params in zip(self.aug_transform_parameters, self.deaug_transform_parameters):
            yield Transformer(
                image_pipeline=build_chain("apply_aug_image", self.aug_transforms, aug_params),
                mask_pipeline=build_chain("apply_deaug_mask", self.deaug_transforms, deaug_params),
                label_pipeline=build_chain("apply_deaug_label", self.deaug_transforms, deaug_params),
                keypoints_pipeline=build_chain("apply_deaug_keypoints", self.deaug_transforms, deaug_params),
            )

    def __len__(self) -> int:
        # One transformer per parameter combination.
        return len(self.aug_transform_parameters)
118 |
119 |
class Merger:
    """Accumulate model outputs and reduce them with the chosen strategy.

    Supported types: 'mean', 'gmean', 'sum', 'max', 'min' and 'tsharpen'
    (temperature sharpen with t=0.5). ``n`` is the expected number of
    appended outputs, used to normalise 'mean'/'tsharpen'/'gmean'.
    """

    _SUPPORTED = ('mean', 'gmean', 'sum', 'max', 'min', 'tsharpen')

    def __init__(
        self,
        type: str = 'mean',
        n: int = 1,
    ):

        if type not in self._SUPPORTED:
            raise ValueError('Not correct merge type `{}`.'.format(type))

        self.output = None  # running accumulator; None until first append
        self.type = type
        self.n = n

    def append(self, x):
        """Fold one more output into the running accumulator."""

        if self.type == 'tsharpen':
            # Sharpen with temperature t=0.5 before accumulating the sum.
            x = x ** 0.5

        if self.output is None:
            self.output = x
        elif self.type == 'gmean':
            self.output = self.output * x
        elif self.type == 'max':
            self.output = F.max(self.output, x)
        elif self.type == 'min':
            self.output = F.min(self.output, x)
        else:
            # 'mean', 'sum' and 'tsharpen' all keep a running sum.
            self.output = self.output + x

    @property
    def result(self):
        """Final merged value; normalises sums/products by ``n`` as needed."""
        if self.type in ('mean', 'tsharpen'):
            return self.output / self.n
        if self.type == 'gmean':
            return self.output ** (1 / self.n)
        if self.type in ('sum', 'max', 'min'):
            return self.output
        raise ValueError('Not correct merge type `{}`.'.format(self.type))
162 |
--------------------------------------------------------------------------------
/ttach/functional.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def rot90(x, k=1):
    """Rotate a batch of images (NCHW) by 90 degrees, ``k`` times, in the H/W plane."""
    return torch.rot90(x, k, dims=(2, 3))
8 |
9 |
def hflip(x):
    """Flip a batch of images (NCHW) left-right, i.e. along the width axis."""
    return torch.flip(x, dims=(3,))
13 |
14 |
def vflip(x):
    """Flip a batch of images (NCHW) top-bottom, i.e. along the height axis."""
    return torch.flip(x, dims=(2,))
18 |
19 |
def sum(x1, x2):
    """Elementwise sum of two tensors (intentionally shadows the builtin here)."""
    return x1 + x2
23 |
24 |
def add(x, value):
    """Add a scalar ``value`` to every element of ``x``."""
    return x + value
28 |
29 |
def max(x1, x2):
    """Elementwise maximum of two tensors (intentionally shadows the builtin here)."""
    return torch.max(x1, x2)
33 |
34 |
def min(x1, x2):
    """Elementwise minimum of two tensors (intentionally shadows the builtin here)."""
    return torch.min(x1, x2)
38 |
39 |
def multiply(x, factor):
    """Scale every element of ``x`` by ``factor``."""
    return x * factor
43 |
44 |
def scale(x, scale_factor, interpolation="nearest", align_corners=None):
    """Resample a batch of images (NCHW) by ``scale_factor``.

    The target size is floor(h * factor) x floor(w * factor), resampled with
    the given interpolation mode.
    """
    height, width = x.shape[2:]
    target_size = (int(height * scale_factor), int(width * scale_factor))
    return F.interpolate(
        x, size=target_size, mode=interpolation, align_corners=align_corners
    )
53 |
54 |
def resize(x, size, interpolation="nearest", align_corners=None):
    """Resample a batch of images (NCHW) to an exact spatial ``size`` (h, w)."""
    return F.interpolate(x, size=size, mode=interpolation, align_corners=align_corners)
58 |
59 |
def crop(x, x_min=None, x_max=None, y_min=None, y_max=None):
    """Slice a batch of images (NCHW) to the window rows y_min:y_max, cols x_min:x_max."""
    return x[:, :, y_min:y_max, x_min:x_max]
63 |
64 |
def crop_lt(x, crop_h, crop_w):
    """Take the top-left ``crop_h`` x ``crop_w`` corner of each image."""
    return x[:, :, :crop_h, :crop_w]
68 |
69 |
def crop_lb(x, crop_h, crop_w):
    """Take the bottom-left ``crop_h`` x ``crop_w`` corner of each image."""
    return x[:, :, -crop_h:, :crop_w]
73 |
74 |
def crop_rt(x, crop_h, crop_w):
    """Take the top-right ``crop_h`` x ``crop_w`` corner of each image."""
    return x[:, :, :crop_h, -crop_w:]
78 |
79 |
def crop_rb(x, crop_h, crop_w):
    """Take the bottom-right ``crop_h`` x ``crop_w`` corner of each image."""
    return x[:, :, -crop_h:, -crop_w:]
83 |
84 |
def center_crop(x, crop_h, crop_w):
    """Crop a ``crop_h`` x ``crop_w`` window centred on each image.

    For odd crop sizes the extra row/column is taken below/right of centre.
    """
    centre_y = x.shape[2] // 2
    centre_x = x.shape[3] // 2

    top = centre_y - crop_h // 2
    bottom = centre_y + crop_h // 2 + crop_h % 2
    left = centre_x - crop_w // 2
    right = centre_x + crop_w // 2 + crop_w % 2

    return x[:, :, top:bottom, left:right]
99 |
100 |
101 | def _disassemble_keypoints(keypoints):
102 | x = keypoints[..., 0]
103 | y = keypoints[..., 1]
104 | return x, y
105 |
106 |
107 | def _assemble_keypoints(x, y):
108 | return torch.stack([x, y], dim=-1)
109 |
110 |
def keypoints_hflip(keypoints):
    """Mirror normalised keypoints horizontally: x -> 1 - x, y unchanged."""
    x, y = keypoints[..., 0], keypoints[..., 1]
    return torch.stack((1. - x, y), dim=-1)
114 |
115 |
def keypoints_vflip(keypoints):
    """Mirror normalised keypoints vertically: y -> 1 - y, x unchanged."""
    x, y = keypoints[..., 0], keypoints[..., 1]
    return torch.stack((x, 1. - y), dim=-1)
119 |
120 |
def keypoints_rot90(keypoints, k=1):
    """Rotate normalised keypoints by 90 degrees ``k`` times (k in 0..3).

    Raises:
        ValueError: if ``k`` is not one of 0, 1, 2, 3.
    """
    if k not in {0, 1, 2, 3}:
        raise ValueError("Parameter k must be in [0:3]")
    if k == 0:
        return keypoints

    x, y = keypoints[..., 0], keypoints[..., 1]
    rotated = {
        1: (y, 1. - x),
        2: (1. - x, 1. - y),
        3: (1. - y, x),
    }[k]
    return torch.stack(rotated, dim=-1)
137 |
--------------------------------------------------------------------------------
/ttach/transforms.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, List, Union, Tuple
3 | from . import functional as F
4 | from .base import DualTransform, ImageOnlyTransform
5 |
6 |
class HorizontalFlip(DualTransform):
    """Left-right flip TTA transform.

    The flip is its own inverse, so masks and keypoints are de-augmented by
    flipping again; labels are flip-invariant.
    """

    identity_param = False

    def __init__(self):
        # parameter values: do nothing / apply the flip
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        return F.hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        return F.hflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        # classification labels are unchanged by a horizontal flip
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        return F.keypoints_hflip(keypoints) if apply else keypoints
32 |
33 |
class VerticalFlip(DualTransform):
    """Up-down flip TTA transform.

    The flip is its own inverse, so masks and keypoints are de-augmented by
    flipping again; labels are flip-invariant.
    """

    identity_param = False

    def __init__(self):
        # parameter values: do nothing / apply the flip
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        return F.vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        return F.vflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        # classification labels are unchanged by a vertical flip
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        return F.keypoints_vflip(keypoints) if apply else keypoints
59 |
60 |
class Rotate90(DualTransform):
    """Rotate images by 0/90/180/270 degrees.

    Args:
        angles (list): rotation angles in degrees; 0 (identity) is prepended
            automatically when missing
    """

    identity_param = 0

    def __init__(self, angles: List[int]):
        # guarantee the identity angle is among the parameters
        if self.identity_param not in angles:
            angles = [self.identity_param] + list(angles)
        super().__init__("angle", angles)

    @staticmethod
    def _angle_to_k(angle):
        # map a (possibly negative) angle to the number of 90-degree turns
        return angle // 90 if angle >= 0 else (angle + 360) // 90

    def apply_aug_image(self, image, angle=0, **kwargs):
        return F.rot90(image, self._angle_to_k(angle))

    def apply_deaug_mask(self, mask, angle=0, **kwargs):
        # rotating by the negated angle inverts the augmentation
        return self.apply_aug_image(mask, -angle)

    def apply_deaug_label(self, label, angle=0, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, angle=0, **kwargs):
        return F.keypoints_rot90(keypoints, k=self._angle_to_k(-angle))
90 |
91 |
class Scale(DualTransform):
    """Scale images by given factors.

    Args:
        scales (List[Union[int, float]]): scale factors for spatial image dimensions
        interpolation (str): one of "nearest"/"linear" (see torch.nn.functional.interpolate)
        align_corners (bool): see torch.nn.functional.interpolate
    """

    identity_param = 1

    def __init__(
        self,
        scales: List[Union[int, float]],
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        # guarantee the identity scale is among the parameters
        if self.identity_param not in scales:
            scales = [self.identity_param] + list(scales)
        self.interpolation = interpolation
        self.align_corners = align_corners

        super().__init__("scale", scales)

    def _interpolate(self, tensor, factor):
        # shared wrapper around F.scale with the configured interpolation options
        return F.scale(
            tensor,
            factor,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_aug_image(self, image, scale=1, **kwargs):
        return image if scale == self.identity_param else self._interpolate(image, scale)

    def apply_deaug_mask(self, mask, scale=1, **kwargs):
        # undo the forward scaling with the reciprocal factor
        return mask if scale == self.identity_param else self._interpolate(mask, 1 / scale)

    def apply_deaug_label(self, label, scale=1, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, scale=1, **kwargs):
        # relative keypoint coordinates are unaffected by spatial scaling
        return keypoints
141 |
142 |
class Resize(DualTransform):
    """Resize images to a set of fixed sizes.

    Args:
        sizes (List[Tuple[int, int]]): target (height, width) sizes
        original_size (Tuple[int, int]): optional original image size, required
            to de-augment masks back to the input resolution
        interpolation (str): one of "nearest"/"linear" (see torch.nn.functional.interpolate)
        align_corners (bool): see torch.nn.functional.interpolate
    """

    def __init__(
        self,
        sizes: List[Tuple[int, int]],
        original_size: Tuple[int, int] = None,
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        # when the original size is known, make sure it is among the parameters
        # so the identity transform is always included
        if original_size is not None and original_size not in sizes:
            sizes = [original_size] + list(sizes)
        self.interpolation = interpolation
        self.align_corners = align_corners
        self.original_size = original_size

        super().__init__("size", sizes)

    def _resize(self, tensor, target_size):
        # shared wrapper around F.resize with the configured interpolation options
        return F.resize(
            tensor,
            target_size,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_aug_image(self, image, size, **kwargs):
        return image if size == self.original_size else self._resize(image, size)

    def apply_deaug_mask(self, mask, size, **kwargs):
        # masks can only be restored when the target resolution is known
        if self.original_size is None:
            raise ValueError(
                "Provide original image size to make mask backward transformation"
            )
        return mask if size == self.original_size else self._resize(mask, self.original_size)

    def apply_deaug_label(self, label, size=1, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, size=1, **kwargs):
        return keypoints
197 |
198 |
class Add(ImageOnlyTransform):
    """Pixel-wise additive TTA transform.

    Args:
        values (List[float]): constants added to every pixel; 0 (identity)
            is prepended automatically when missing
    """

    identity_param = 0

    def __init__(self, values: List[float]):
        # guarantee the identity value is among the parameters
        if self.identity_param not in values:
            values = [self.identity_param] + list(values)
        super().__init__("value", values)

    def apply_aug_image(self, image, value=0, **kwargs):
        return image if value == self.identity_param else F.add(image, value)
218 |
219 |
class Multiply(ImageOnlyTransform):
    """Pixel-wise multiplicative TTA transform.

    Args:
        factors (List[float]): factors every pixel is multiplied by; 1 (identity)
            is prepended automatically when missing
    """

    identity_param = 1

    def __init__(self, factors: List[float]):
        # guarantee the identity factor is among the parameters
        if self.identity_param not in factors:
            factors = [self.identity_param] + list(factors)
        super().__init__("factor", factors)

    def apply_aug_image(self, image, factor=1, **kwargs):
        return image if factor == self.identity_param else F.multiply(image, factor)
238 |
239 |
class FiveCrops(ImageOnlyTransform):
    """Produce five fixed crops: the four corners plus the center.

    Args:
        crop_height (int): crop height in pixels
        crop_width (int): crop width in pixels
    """

    def __init__(self, crop_height, crop_width):
        # one pre-bound crop callable per corner, plus the center crop;
        # order matches the original: lt, lb, rb, rt, center
        crop_functions = tuple(
            partial(fn, crop_h=crop_height, crop_w=crop_width)
            for fn in (F.crop_lt, F.crop_lb, F.crop_rb, F.crop_rt, F.center_crop)
        )
        super().__init__("crop_fn", crop_functions)

    def apply_aug_image(self, image, crop_fn=None, **kwargs):
        return crop_fn(image)

    def apply_deaug_mask(self, mask, **kwargs):
        # cropping discards pixels, so masks cannot be de-augmented
        raise ValueError("`FiveCrop` augmentation is not suitable for mask!")

    def apply_deaug_keypoints(self, keypoints, **kwargs):
        raise ValueError("`FiveCrop` augmentation is not suitable for keypoints!")
266 |
--------------------------------------------------------------------------------
/ttach/wrappers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Optional, Mapping, Union, Tuple
4 |
5 | from .base import Merger, Compose
6 |
7 |
class SegmentationTTAWrapper(nn.Module):
    """Test-time-augmentation wrapper for a segmentation model.

    Args:
        model (torch.nn.Module): segmentation model with single input and single output
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_mask_key (str): if model output is `dict`, specify which key belong to `mask`
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_mask_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_mask_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        # augment -> predict -> (unwrap dict) -> de-augment -> accumulate
        for transformer in self.transforms:
            prediction = self.model(transformer.augment_image(image), *args)
            if self.output_key is not None:
                prediction = prediction[self.output_key]
            merger.append(transformer.deaugment_mask(prediction))

        result = merger.result
        # re-wrap in a dict when the model output was dict-like
        return result if self.output_key is None else {self.output_key: result}
50 |
51 |
class ClassificationTTAWrapper(nn.Module):
    """Test-time-augmentation wrapper for a classification model.

    Args:
        model (torch.nn.Module): classification model with single input and single output
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_label_key (str): if model output is `dict`, specify which key belong to `label`
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_label_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_label_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        # augment -> predict -> (unwrap dict) -> de-augment -> accumulate
        for transformer in self.transforms:
            prediction = self.model(transformer.augment_image(image), *args)
            if self.output_key is not None:
                prediction = prediction[self.output_key]
            merger.append(transformer.deaugment_label(prediction))

        result = merger.result
        # re-wrap in a dict when the model output was dict-like
        return result if self.output_key is None else {self.output_key: result}
94 |
95 |
class KeypointsTTAWrapper(nn.Module):
    """Wrap PyTorch nn.Module (keypoints model) with test time augmentation transforms.

    Args:
        model (torch.nn.Module): keypoints model with single input and single output
            in format [x1, y1, x2, y2, ..., xn, yn]
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_keypoints_key (str): if model output is `dict`, specify which key belongs
            to the keypoints tensor
        scaled (bool): True if model returns x, y values already scaled to [0, 1],
            else False (pixel coordinates)

    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_keypoints_key: Optional[str] = None,
        scaled: bool = False,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_keypoints_key
        self.scaled = scaled

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        merger = Merger(type=self.merge_mode, n=len(self.transforms))
        size = image.size()
        # dims 2 and 3 are read as height/width, i.e. input is assumed to be
        # a 4D (batch, channels, height, width) tensor
        batch_size, image_height, image_width = size[0], size[2], size[3]

        for transformer in self.transforms:
            augmented_image = transformer.augment_image(image)
            augmented_output = self.model(augmented_image, *args)

            if self.output_key is not None:
                augmented_output = augmented_output[self.output_key]

            # flat [x1, y1, ..., xn, yn] -> (batch, n_keypoints, 2)
            augmented_output = augmented_output.reshape(batch_size, -1, 2)
            if not self.scaled:
                # normalize pixel coordinates to [0, 1] so the keypoint
                # de-augmentations (which operate on relative coords) apply.
                # NOTE(review): this divides the model's output tensor in
                # place — assumes the caller does not reuse it; verify.
                augmented_output[..., 0] /= image_width
                augmented_output[..., 1] /= image_height

            deaugmented_output = transformer.deaugment_keypoints(augmented_output)
            merger.append(deaugmented_output)

        result = merger.result

        if not self.scaled:
            # convert the merged normalized coordinates back to pixels (in place)
            result[..., 0] *= image_width
            result[..., 1] *= image_height
        # restore the flat [x1, y1, ..., xn, yn] layout expected by callers
        result = result.reshape(batch_size, -1)

        if self.output_key is not None:
            result = {self.output_key: result}

        return result
158 |
--------------------------------------------------------------------------------