├── .github
│   └── workflows
│       ├── sphinx.yml
│       └── testing.yml
├── .gitignore
├── ACKNOWLEDGEMENTS
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   └── source
│       ├── conf.py
│       ├── index.rst
│       ├── python
│       │   ├── core.rst
│       │   ├── transforms.rst
│       │   ├── utils.rst
│       │   └── wrappers.rst
│       └── tutorials
│           ├── 000-About-the-Package.md
│           ├── 001-The-Structure-of-Parametrized-Transforms.md
│           ├── 002-How-to-Write-Your-Own-Transforms.md
│           ├── 003-A-Brief-Introduction-to-the-Transforms-in-This-Package.md
│           ├── 004-Parametrized-Transforms-in-Action.md
│           ├── 005-Migrate-To-and-From-torch-in-Three-Easy-Steps.md
│           ├── 999-In-a-Nutshell.md
│           └── assets
│               ├── 004-crazytx-cat-aug-1.jpeg
│               ├── 004-crazytx-cat-aug-2.jpeg
│               ├── 004-crazytx-cat-orig.jpeg
│               ├── 004-pnrce-cat-aug-1.jpeg
│               ├── 004-pnrce-cat-aug-2.jpeg
│               ├── 004-pnrce-cat-orig.jpeg
│               ├── cat-aug-0.jpeg
│               ├── cat-aug-1.jpeg
│               ├── cat-aug-2.jpeg
│               ├── cat-aug-3.jpeg
│               ├── cat-aug-4.jpeg
│               ├── cat-aug-5.jpeg
│               ├── cat.jpeg
│               ├── dog-aug-0.jpeg
│               ├── dog-aug-1.jpeg
│               ├── dog-aug-2.jpeg
│               ├── dog-aug-3.jpeg
│               ├── dog-aug-4.jpeg
│               ├── dog-aug-5.jpeg
│               ├── dog-aug-6.jpeg
│               ├── dog-aug-7.jpeg
│               ├── dog.jpeg
│               ├── no-dog-aug-0.jpeg
│               ├── no-dog-aug-1.jpeg
│               ├── no-dog-aug-2.jpeg
│               └── no-dog-aug-3.jpeg
├── examples
│   ├── 002-001-RandomColorErasing.ipynb
│   ├── 002-002-RandomSubsetApply.ipynb
│   ├── 002-003-Visualizations.ipynb
│   ├── 004-001-Playground.ipynb
│   └── 005-001-Torch-Torchvision-to-Parametrized-Transforms.ipynb
├── parameterized_transforms
│   ├── __init__.py
│   ├── core.py
│   ├── transforms.py
│   ├── utils.py
│   └── wrappers.py
├── pyproject.toml
├── requirements.txt
└── tests
    ├── test_atomic_transforms
    │   ├── __init__.py
    │   ├── test_CenterCrop.py
    │   ├── test_ColorJitter.py
    │   ├── test_ConvertImageDtype.py
    │   ├── test_ElasticTransform.py
    │   ├── test_FiveCrop.py
    │   ├── test_GaussianBlur.py
    │   ├── test_Grayscale.py
    │   ├── test_Lambda.py
    │   ├── test_LinearTransformation.py
    │   ├── test_Normalize.py
    │   ├── test_PILToTensor.py
    │   ├── test_Pad.py
    │   ├── test_RandomAdjustSharpness.py
    │   ├── test_RandomAffine.py
    │   ├── test_RandomAutocontrast.py
    │   ├── test_RandomCrop.py
    │   ├── test_RandomEqualize.py
    │   ├── test_RandomErasing.py
    │   ├── test_RandomGrayscale.py
    │   ├── test_RandomHorizontalFlip.py
    │   ├── test_RandomInvert.py
    │   ├── test_RandomPerspective.py
    │   ├── test_RandomPosterize.py
    │   ├── test_RandomResizedCrop.py
    │   ├── test_RandomRotation.py
    │   ├── test_RandomSolarize.py
    │   ├── test_RandomVerticalFlip.py
    │   ├── test_Resize.py
    │   ├── test_TenCrop.py
    │   ├── test_ToPILImage.py
    │   └── test_ToTensor.py
    ├── test_composing_transforms
    │   ├── __init__.py
    │   ├── test_Compose.py
    │   ├── test_RandomApply.py
    │   ├── test_RandomChoice.py
    │   └── test_RandomOrder.py
    └── test_functions.py
/.github/workflows/sphinx.yml:
--------------------------------------------------------------------------------
1 | name: Deploy sphinx site to Pages
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | workflow_dispatch:
7 |
8 | permissions:
9 | contents: read
10 | pages: write
11 | id-token: write
12 |
13 | concurrency:
14 | group: "pages"
15 | cancel-in-progress: false
16 |
17 | jobs:
18 | build:
19 | runs-on: ubuntu-latest
20 | permissions:
21 | contents: write
22 | steps:
23 | - uses: actions/checkout@v4
24 | - uses: actions/setup-python@v5
25 | with:
26 | python-version: '3.9'
27 | - name: Build with sphinx
28 | run: |
29 | pip install -e .
30 | cd docs
31 | pip install sphinx sphinx-book-theme myst-parser
32 | make html
33 | - name: Upload artifact
34 | uses: actions/upload-pages-artifact@v3
35 | with:
36 | path: ./docs/build/html
37 |
38 | deploy:
39 | environment:
40 | name: github-pages
41 | url: ${{ steps.deployment.outputs.page_url }}
42 | runs-on: ubuntu-latest
43 | needs: build
44 | steps:
45 | - name: Deploy to GitHub Pages
46 | id: deployment
47 | uses: actions/deploy-pages@v4
--------------------------------------------------------------------------------
/.github/workflows/testing.yml:
--------------------------------------------------------------------------------
1 | name: Run tests upon pull request events
2 |
3 | on:
4 | pull_request:
5 | branches: ["main"]
6 |
7 | jobs:
8 | pre-merge:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - uses: actions/setup-python@v5
13 | with:
14 | python-version: '3.9'
15 | - name: Build package and test dependencies to run unittests suite using pytest
16 | run: |
17 | pip install -e .
18 | pip install -e '.[test]'
19 | pytest
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # IDE related
132 | .idea
133 |
134 | # Testing related
135 | data
136 | .DS_Store
137 |
--------------------------------------------------------------------------------
/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Parameterized Transforms Software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | Note:
6 | The use of code from the below mentioned repository(ies) is to ensure performance parity.
7 |
8 | _____________________
9 |
10 | Soumith Chintala (pytorch/vision)
11 |
12 | BSD 3-Clause License
13 |
14 | Copyright (c) Soumith Chintala 2016,
15 | All rights reserved.
16 |
17 | Redistribution and use in source and binary forms, with or without
18 | modification, are permitted provided that the following conditions are met:
19 |
20 | * Redistributions of source code must retain the above copyright notice, this
21 | list of conditions and the following disclaimer.
22 |
23 | * Redistributions in binary form must reproduce the above copyright notice,
24 | this list of conditions and the following disclaimer in the documentation
25 | and/or other materials provided with the distribution.
26 |
27 | * Neither the name of the copyright holder nor the names of its
28 | contributors may be used to endorse or promote products derived from
29 | this software without specific prior written permission.
30 |
31 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
32 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
33 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
34 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
35 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
36 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
37 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
38 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
39 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
40 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Dhekane"
5 | given-names: "Eeshan Gunesh"
6 | orcid: "https://orcid.org/0009-0006-3026-6258"
7 | title: "Parameterized Transforms"
8 | version: 1.0.0
9 | date-released: 2025-02-15
10 | url: "https://github.com/apple/parameterized-transforms"
11 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
 3 | Thanks for your interest in contributing. There are limited plans for future development of this repository.
 4 |
 5 | While we welcome new pull requests and issues, please note that our responses may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parameterized Transforms
2 |
3 |
4 | ## Index
5 | 1. [About the Package](#about-the-package)
6 | 2. [Installation](#installation)
7 | 3. [Getting Started](#getting-started)
8 |
9 |
10 |
11 | ## About the Package
12 | * The package provides a uniform, modular, and easily extendable implementation of `torchvision`-based transforms that exposes their parameterization.
13 | * With this access, the transforms enable the following two important functionalities (see the sketch below):
14 |   * Given an image, the transform can return an augmentation along with the parameters used to produce it.
15 |   * Given an image and augmentation parameters, the transform can return the corresponding augmented image.
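
For a quick flavor, here is a minimal sketch (assuming `image` is an already-loaded `PIL` image; see the tutorials for full details):
```python
import parameterized_transforms.transforms as ptx

tx = ptx.RandomRotation(degrees=45)

# Functionality 1: augment the image and get the parameters that were used.
aug, params = tx(image)

# Functionality 2: re-apply the exact same augmentation from those parameters.
aug_again, leftover = tx.consume_transform(image, params)
# `leftover` is (), since all parameters are consumed.
```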
16 |
17 |
18 |
19 | ## Installation
20 | - To install the package directly, run the following commands:
21 | ```
22 | git clone https://github.com/apple/parameterized-transforms
23 | cd parameterized-transforms
24 | pip install -e .
25 | ```
26 | - To install the package via `pip`, run the following command:
27 | ```
28 | pip install --upgrade git+https://github.com/apple/parameterized-transforms.git
29 | ```
30 | - To run the unit tests locally, follow these steps:
31 | ```
32 | git clone https://github.com/apple/parameterized-transforms
33 | cd parameterized-transforms
34 | pip install -e .
35 | pip install -e '.[test]'
36 | pytest
37 | ```
38 |
39 |
40 |
41 | ## Getting Started
42 | * To understand the structure of parameterized transforms and the details of the package, we recommend starting
43 | with
44 | [The First Tutorial](https://apple.github.io/parameterized-transforms/tutorials/000-About-the-Package.html)
45 | of our
46 | [Tutorial Series](https://apple.github.io/parameterized-transforms/).
47 | * For a quick start, check out [Parameterized Transforms in a Nutshell](https://apple.github.io/parameterized-transforms/tutorials/999-In-a-Nutshell.html).
48 |
49 | ---
50 |
51 | ## Acknowledgement
52 | In its development, this project received help from multiple researchers, engineers, and other contributors from Apple.
53 | Special thanks to: Tim Kolecke, Jason Ramapuram, Russ Webb, David Koski, Mike Drob, Megan Maher Welsh, Marco Cuturi Cameto,
54 | Dan Busbridge, Xavier Suau Cuadros, and Miguel Sarabia del Castillo.
55 |
56 | ## Citation
57 | If you find this package useful and want to cite our work, here is the citation:
58 | ```
59 | @software{Dhekane_Parameterized_Transforms_2025,
60 | author = {Dhekane, Eeshan Gunesh},
61 | month = {2},
62 | title = {{Parameterized Transforms}},
63 | url = {https://github.com/apple/parameterized-transforms},
64 | version = {1.0.0},
65 | year = {2025}
66 | }
67 | ```
68 |
69 | ---
70 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = 'Parameterized Transforms'
10 | copyright = '2025, Apple Inc'
11 | author = 'Eeshan Gunesh Dhekane'
12 |
13 | # -- General configuration ---------------------------------------------------
14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
15 |
16 | extensions = [
17 | 'myst_parser',
18 | 'sphinx.ext.autodoc',
19 | 'sphinx.ext.autosummary',
20 | 'sphinx.ext.intersphinx',
21 | 'sphinx.ext.githubpages',
22 | 'sphinx.ext.napoleon',
23 | ]
24 |
25 | myst_enable_extensions = ['html_image']
26 | source_suffix = ['.rst', '.md']
27 | autoclass_content = 'both'
28 | templates_path = ['_templates']
29 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
30 |
31 |
32 |
33 | # -- Options for HTML output -------------------------------------------------
34 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
35 |
36 | html_theme = 'sphinx_book_theme'
37 | html_title = 'Parameterized Transforms'
38 | html_theme_options = {
39 | 'repository_url': 'https://github.com/apple/parameterized-transforms',
40 | 'use_repository_button': True,
41 | 'navigation_with_keys': False,
42 | }
43 | html_static_path = ['_static']
44 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Parameterized Transforms
2 | ========================
3 |
4 |
5 | About the Package
6 | -----------------
7 |
 8 | * The package provides a uniform, modular, and easily extendable implementation of `torchvision`-based transforms that exposes their parameterization.
 9 |
10 | * With this access, the transforms enable the following two important functionalities (see the sketch below):
11 |
12 |   * Given an image, return an augmentation along with the parameters used to produce it.
13 |
14 |   * Given an image and augmentation parameters, return the corresponding augmented image.
15 |
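As a minimal sketch (assuming ``image`` is an already-loaded ``PIL`` image; see the tutorials for full details):

.. code-block:: python

   import parameterized_transforms.transforms as ptx

   tx = ptx.RandomRotation(degrees=45)

   # Augment the image and get the parameters that were used.
   aug, params = tx(image)

   # Re-apply the exact same augmentation from those parameters.
   aug_again, leftover = tx.consume_transform(image, params)
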
16 |
17 | Installation
18 | ------------
19 | * To install the package directly, run the following commands:
20 |
21 | .. code-block:: bash
22 |
23 | git clone git@github.com:apple/parameterized-transforms.git
24 | cd parameterized-transforms
25 | pip install -e .
26 |
27 |
28 | * To install the package via `pip`, run the following command:
29 |
30 | .. code-block:: bash
31 |
32 | pip install --upgrade git+https://github.com/apple/parameterized-transforms.git
33 |
34 |
35 | Get Started
36 | -----------
37 | * To understand the structure of parameterized transforms and the details of the package, we recommend starting with the :ref:`Tutorial Series <Tutorial-Series-label>`.
38 |
39 | * For a quick start, check out :ref:`Parameterized Transforms in a Nutshell <Quick-Start-label>`.
40 |
41 |
42 | Important Links
43 | ---------------
44 |
45 | .. _Quick-Start-label:
46 |
47 | .. toctree::
48 | :caption: Quick-Start
49 | :maxdepth: 1
50 |
51 | tutorials/999-In-a-Nutshell
52 |
53 | .. _Tutorial-Series-label:
54 |
55 | .. toctree::
56 | :caption: Tutorial Series
57 | :maxdepth: 1
58 |
59 | tutorials/000-About-the-Package
60 | tutorials/001-The-Structure-of-Parametrized-Transforms
61 | tutorials/002-How-to-Write-Your-Own-Transforms
62 | tutorials/003-A-Brief-Introduction-to-the-Transforms-in-This-Package
63 | tutorials/004-Parametrized-Transforms-in-Action
64 | tutorials/005-Migrate-To-and-From-torch-in-Three-Easy-Steps
65 |
66 |
67 | .. toctree::
68 | :caption: Python API Reference
69 | :maxdepth: 1
70 |
71 | python/core
72 | python/transforms
73 | python/utils
74 | python/wrappers
75 |
--------------------------------------------------------------------------------
/docs/source/python/core.rst:
--------------------------------------------------------------------------------
1 | core
2 | ====
3 |
4 | .. automodule:: parameterized_transforms.core
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/python/transforms.rst:
--------------------------------------------------------------------------------
1 | transforms
2 | ==========
3 |
4 | .. automodule:: parameterized_transforms.transforms
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/python/utils.rst:
--------------------------------------------------------------------------------
1 | utils
2 | =====
3 |
4 | .. automodule:: parameterized_transforms.utils
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/python/wrappers.rst:
--------------------------------------------------------------------------------
1 | wrappers
2 | ========
3 |
4 | .. automodule:: parameterized_transforms.wrappers
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/tutorials/000-About-the-Package.md:
--------------------------------------------------------------------------------
1 | # About the Package
2 |
3 | - 5 Minute Read
4 |
5 |
6 | ## Summary
7 | * This tutorial describes the three important aspects of the `Parameterized Transforms` package:
 8 |   1. **Why** do we need this package?
 9 |   2. **What** does the package provide?
10 |   3. **How** do we use the package?
11 |
12 |
13 | * **NOTE:** We will be using the terms **Augmentation (noun) / Augment (verb)** interchangeably with
14 | **Transform (noun) / Transform (verb)** throughout the tutorials.
15 |
16 |
17 | ## The *Why* Aspect
18 | * Augmentation strategies are important in computer vision research for improving the performance of deep learning approaches.
19 | * Popular libraries like `torchvision` and `kornia` provide implementations of widely used and important transforms.
20 | * Many recent research ideas revolve around using the information of augmentation parameters in order to learn better representations.
21 | In this context, different popular libraries have different pros and cons:
22 | * For instance, most recent deep learning approaches define their augmentation stacks in terms of
23 | `torchvision`-based transforms, experiment with them, and report the best-performing stacks.
24 | However, `torchvision`-based transforms do NOT provide access to their parameters, thereby limiting the research possibilities aimed at extracting information provided by augmentation parameters to learn better data representations.
25 | * On the other hand, although `kornia`-based augmentation stacks do provide access to the parameters of the augmentations,
26 | reproducing results obtained with `torchvision` stacks using `kornia`-based augmentations is difficult due to the differences in their implementation.
27 | * Ideally, we want transform implementations with the following desired properties:
28 | 1. they can provide access to their parameters by exposing them,
29 | 2. they allow reproducible augmentations by enabling application of the transform defined by given parameters,
30 | 3. they are easy to subclass and extend in order to tweak their functionality, and
31 | 4. they have implementations that match those of the transforms used in obtaining the state-of-the-art results (mostly, `torchvision`).
32 | * This is very difficult to achieve with any of the currently existing libraries.
33 |
34 |
35 |
36 | ## The *What* Aspect
37 | * What this package provides is a modular, uniform, and easily extendable skeleton with a re-implementation of `torchvision`-based
38 | transforms that gives you access to their augmentation parameters and allows reproducible augmentations.
39 | * In particular, these transforms can perform two extremely crucial tasks associated with exposing their parameters:
40 | 1. Given an image, the transform can return an augmentation along with the parameters used for the augmentation.
41 | 2. Given an image and well-defined augmentation parameters, the transform can return the corresponding augmented image.
42 | * The uniform template for all transforms and the modular re-implementation mean that you can easily subclass the
43 | transforms and tweak their functionality.
44 | * In addition, you can write your own custom transforms using the provided templates and combine them seamlessly
45 | with other custom or package-defined transforms for your experimentation.
46 |
47 |
48 |
49 | ## The *How* Aspect
50 | * To start using the package, we recommend the following--
51 | 1. Read through the [Prerequisites](#prerequisites) listed below and be well-acquainted with them.
52 | 2. [Install the Package](https://github.com/apple/parameterized-transforms/blob/main/README.md#installation) as described in the link.
53 | 3. Read through the [Tutorial Series](#tutorials-in-a-nutshell).
54 | 4. After that, you should be ready to write and experiment with parameterized transforms!
55 |
56 |
57 |
58 | ## Prerequisites
59 | Here are the prerequisites for this package--
60 | * `numpy`: being comfortable with `numpy` arrays and operations,
61 | * `PIL`: being comfortable with basic `PIL` operations and the `PIL.Image.Image` class,
62 | * `torch`: being comfortable with `torch` tensors and operations, and
63 | * `torchvision`: being comfortable with `torchvision` transforms and operations.
64 |
65 |
66 |
67 | ## A Preview of All Tutorials
68 | * Here is an overview of the tutorials in this series and the topics they cover--
69 |
70 | | Title | Contents |
71 | |-------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
72 | | [0. About the Package](000-About-the-Package.md) | An overview of the package |
73 | | [1. The Structure of Parametrized Transforms](001-The-Structure-of-Parametrized-Transforms.md) | Explanation of the base classes `Transform`, `AtomicTransform`, and `ComposingTransform` |
74 | | [2. How to Write Your Own Transforms](002-How-to-Write-Your-Own-Transforms.md) | A walk-through of writing custom transforms-- an atomic transform named `RandomColorErasing` and a composing transform named `RandomSubsetApply` |
75 | | [3. A Brief Introduction to the Transforms in This Package](003-A-Brief-Introduction-to-the-Transforms-in-This-Package.md) | Information about all the transforms provided in this package; demonstrations of some of the important transforms from the package |
76 | | [4. Parametrized Transforms in Action](004-Parametrized-Transforms-in-Action.md) | Visualization of augmentations produced by the custom transforms; combining custom transforms with the ones from the package; extending transforms easily to tweak their behaviors |
77 | | [5. Migrate From `torch`/`torchvision` to `parameterized_transforms` in Three Easy Steps](005-Migrate-To-and-From-torch-in-Three-Easy-Steps.md) | Instructions to easily modify code with `torch`-based datasets/loaders and `torchvision`-based transforms to use the parameterized transforms |
78 |
79 |
80 |
81 | ## Credits
82 | In case you find our work useful in your research, you can use the following `bibtex` entry to cite us--
83 | ```text
84 | @software{Dhekane_Parameterized_Transforms_2025,
85 | author = {Dhekane, Eeshan Gunesh},
86 | month = {2},
87 | title = {{Parameterized Transforms}},
88 | url = {https://github.com/apple/parameterized-transforms},
89 | version = {1.0.0},
90 | year = {2025}
91 | }
92 | ```
93 |
94 |
95 |
96 | ## About the Next Tutorial
97 | * In the next tutorial [001-The-Structure-of-Parametrized-Transforms.md](001-The-Structure-of-Parametrized-Transforms.md), we will describe the core structure of parameterized transforms.
98 | * We will see two different types of transforms, **Atomic** and **Composing**, and describe their details.
99 |
--------------------------------------------------------------------------------
/docs/source/tutorials/999-In-a-Nutshell.md:
--------------------------------------------------------------------------------
1 | # Parameterized Transforms in a Nutshell
2 |
3 |
 4 | 1. Parameterized transforms take an image and a tuple of parameters as input and produce an augmentation along with modified parameters.
 5 | The mode of any parameterized transform can be `CASCADE` or `CONSUME`.
 6 | In `CASCADE` mode, the transform generates an augmentation and appends the parameters used for this augmentation to the input parameters.
 7 | In `CONSUME` mode, the transform consumes the parameters it needs from the beginning of the tuple and generates the corresponding augmentation.
8 | The augmentation and the modified parameters are then returned.
9 | With these two modes, we can have reproducible, flexible augmentations and much more.
10 | ```python
11 | # Signature of parameterized transforms.
12 | import parameterized_transforms.transforms as ptx
13 | import parameterized_transforms.core as ptc
14 |
15 | tx1 = ptx.RandomRotation(degrees=45) # Default mode: CASCADE.
16 | params1 = (3, 2, 0, 1, 0.3, -2.5) # Example: Parameters from previous parameterized transforms.
17 |
18 | augmentation1, modified_params = tx1(image, params1)
19 | # ALTERNATIVELY: augmentation1, modified_params = tx1.cascade_transform(image, params1)
20 | # augmentation: Image rotated by a random angle, say 31.25 degrees.
21 | # modified_params: (3, 2, 0, 1, 0.3, -2.5, 31.25). Note the appended 31.25 angle value.
22 |
23 | tx2 = ptx.RandomRotation(degrees=45, tx_mode=ptc.TransformMode.CONSUME)
24 | params2 = (31.25, 0, 1, 0.5, -1.7) # Example: Parameter for `RandomRotation` (31.25) and possibly other parameterized transforms.
25 | augmentation2, remaining_params = tx2(image, params2)
26 | # ALTERNATIVELY: tx1.consume_transform(image, params2) or tx2.consume_transform(image, params2)
27 | # augmentation2: The same augmentation as augmentation1 above.
28 | # remaining_params: (0, 1, 0.5, -1.7).
29 | ```
30 |
31 | 2. Parameterized versions of all `torchvision`-based transforms are supported.
32 | ```python
33 | # Example
34 | import parameterized_transforms.transforms as ptx
35 |
36 | tx = ptx.Compose(
37 | [
38 | ptx.RandomHorizontalFlip(p=0.5),
39 | ptx.RandomApply(
40 | [
41 | ptx.ColorJitter(
42 | brightness=0.1,
43 | contrast=0.1,
44 | saturation=0.1,
45 | hue=0.1,
46 | )
47 | ],
48 | p=0.5,
49 | ),
50 | ptx.ToTensor(),
51 | ptx.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
52 | ]
53 | )
54 |
55 | augmentation, params = tx(image) # Default params: ().
56 | augmentation_1, empty_params = tx.consume_transform(image, params)
57 | # augmentation_1: Same as the augmentation above.
58 | # empty_params: (), as all params are extracted and used.
59 | ```
60 |
61 | 3. You can [write your own transforms](https://apple.github.io/parameterized-transforms/tutorials/002-How-to-Write-Your-Own-Transforms.html) that adhere to the structure of parameterized transforms.
62 | Then, your transforms will work seamlessly with those from the package!
63 | ```python
64 | import parameterized_transforms.transforms as ptx
65 |
66 | tx = RandomSubsetApply( # Your custom transform
67 | [
68 | RandomColorErasing(), # Your custom transform
69 | ptx.RandomHorizontalFlip(p=0.5),
70 | ptx.RandomApply(
71 | [
72 | ptx.ColorJitter(
73 | brightness=0.1,
74 | contrast=0.1,
75 | saturation=0.1,
76 | hue=0.1,
77 | )
78 | ],
79 | p=0.5,
80 | )
81 | ]
82 | )
83 |
84 | augmentation, params = tx(image) # Default params: ().
85 | augmentation_1, empty_params = tx.consume_transform(image, params)
86 | # augmentation_1: Same as the augmentation above.
87 | # empty_params: (), as all params are extracted and used.
88 | ```
89 |
90 | 4. You can use parameterized transforms with `torch`/`torchvision` datasets directly.
91 | However, in order to get the parameters as a single tensor of shape `[batch_size=B, num_params=P]`, we recommend wrapping your transform in the `CastParamsToTensor` wrapper, as sketched below.
92 | More on this in the [tutorial on working with torch/torchvision](https://apple.github.io/parameterized-transforms/tutorials/005-Migrate-To-and-From-torch-in-Three-Easy-Steps.html).
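
A minimal sketch of this setup (for illustration, it is assumed here that `CastParamsToTensor` from `parameterized_transforms.wrappers` simply wraps a parameterized transform; see the tutorial above for the exact API):
```python
import torch
import torchvision.datasets as datasets

import parameterized_transforms.transforms as ptx
import parameterized_transforms.wrappers as ptw

# Assumption for illustration: the wrapper takes the transform to wrap as its argument.
tx = ptw.CastParamsToTensor(
    ptx.Compose(
        [
            ptx.RandomHorizontalFlip(p=0.5),
            ptx.ToTensor(),
        ]
    )
)

# Any torch/torchvision dataset works; the transform now returns (image_tensor, params_tensor).
dataset = datasets.CIFAR10(root="./data", download=True, transform=tx)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

(images, params), labels = next(iter(loader))
# images: shape [8, 3, 32, 32]; params: shape [batch_size=8, num_params=P].
```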
93 |
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/004-crazytx-cat-aug-1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-crazytx-cat-aug-1.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/004-crazytx-cat-aug-2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-crazytx-cat-aug-2.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/004-crazytx-cat-orig.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-crazytx-cat-orig.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/004-pnrce-cat-aug-1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-pnrce-cat-aug-1.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/004-pnrce-cat-aug-2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-pnrce-cat-aug-2.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/004-pnrce-cat-orig.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-pnrce-cat-orig.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat-aug-0.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-0.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat-aug-1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-1.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat-aug-2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-2.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat-aug-3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-3.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat-aug-4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-4.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat-aug-5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-5.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/cat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-0.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-0.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-1.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-2.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-3.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-4.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-5.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-6.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-6.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog-aug-7.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-7.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/no-dog-aug-0.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-0.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/no-dog-aug-1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-1.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/no-dog-aug-2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-2.jpeg
--------------------------------------------------------------------------------
/docs/source/tutorials/assets/no-dog-aug-3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-3.jpeg
--------------------------------------------------------------------------------
/parameterized_transforms/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | from parameterized_transforms import core
19 | from parameterized_transforms import transforms
20 | from parameterized_transforms import utils
21 |
22 |
23 | ATOMIC_TRANSFORMS = {
24 | "ToTensor": transforms.ToTensor,
25 | "PILToTensor": transforms.PILToTensor,
26 | "ConvertImageDtype": transforms.ConvertImageDtype,
27 | "ToPILImage": transforms.ToPILImage,
28 | "Normalize": transforms.Normalize,
29 | "Resize": transforms.Resize,
30 | "CenterCrop": transforms.CenterCrop,
31 | "Pad": transforms.Pad,
32 | "Lambda": transforms.Lambda,
33 | "RandomCrop": transforms.RandomCrop,
34 | "RandomHorizontalFlip": transforms.RandomHorizontalFlip,
35 | "RandomVerticalFlip": transforms.RandomVerticalFlip,
36 | "RandomPerspective": transforms.RandomPerspective,
37 | "RandomResizedCrop": transforms.RandomResizedCrop,
38 | "FiveCrop": transforms.FiveCrop,
39 | "TenCrop": transforms.TenCrop,
40 | "LinearTransformation": transforms.LinearTransformation,
41 | "ColorJitter": transforms.ColorJitter,
42 | "RandomRotation": transforms.RandomRotation,
43 | "RandomAffine": transforms.RandomAffine,
44 | "Grayscale": transforms.Grayscale,
45 | "RandomGrayscale": transforms.RandomGrayscale,
46 | "RandomErasing": transforms.RandomErasing,
47 | "GaussianBlur": transforms.GaussianBlur,
48 | "RandomInvert": transforms.RandomInvert,
49 | "RandomPosterize": transforms.RandomPosterize,
50 | "RandomSolarize": transforms.RandomSolarize,
51 | "RandomAdjustSharpness": transforms.RandomAdjustSharpness,
52 | "RandomAutocontrast": transforms.RandomAutocontrast,
53 | "RandomEqualize": transforms.RandomEqualize,
54 | "ElasticTransform": transforms.ElasticTransform,
55 | }
56 |
57 |
58 | COMPOSING_TRANSFORMS = {
59 | "Compose": transforms.Compose,
60 | "RandomApply": transforms.RandomApply,
61 | "RandomOrder": transforms.RandomOrder,
62 | "RandomChoice": transforms.RandomChoice,
63 | }
64 |
65 |
66 | TRANSFORMS = dict(
67 | tuple(ATOMIC_TRANSFORMS.items())
68 | + tuple(COMPOSING_TRANSFORMS.items())
69 | )
70 |
--------------------------------------------------------------------------------
/parameterized_transforms/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | import enum
19 | import typing as t
20 |
21 | from parameterized_transforms.core import SCALAR_TYPE
22 |
23 | SCALAR_RANGE_TYPE = t.Union[t.List[SCALAR_TYPE], t.Tuple[SCALAR_TYPE, SCALAR_TYPE]]
24 |
25 |
26 | # --------------------------------------------------------------------------------
27 | # Transforms-related functionality.
28 | # --------------------------------------------------------------------------------
29 |
30 |
31 | def get_total_params_count(transforms: t.Any) -> int:
32 | """Returns the total number of processed parameters for the given
33 | collection of core transformations.
34 |
35 | :param transforms: The collection of core transforms.
36 |
37 | :return: The total number of processed parameters.
38 | """
39 | param_count: int = 0
40 |
41 | if isinstance(transforms, list) or isinstance(transforms, tuple):
42 |
43 | for transform in transforms:
44 |
45 | try:
46 | transform_param_count = transform.param_count
47 | except Exception as e:
48 | raise AttributeError(
49 | "ERROR | `param_count` access "
50 | f"for transform: {transform} failed; "
51 | f"hit error:\n{e}"
52 | )
53 |
54 | param_count += transform_param_count
55 |
56 | elif isinstance(transforms, dict):
57 |
58 | for _, transform in transforms.items():
59 |
60 | try:
61 | transform_param_count = transform.param_count
62 | except Exception as e:
63 | raise AttributeError(
64 | f"ERROR | `param_count` access "
65 | f"for transforms: {transform} failed; "
66 | f"hit error:\n{e}"
67 | )
68 |
69 | param_count += transform_param_count
70 |
71 | else:
72 |
73 | try:
74 | param_count += transforms.param_count
75 | except Exception as e:
76 | raise AttributeError(
77 | "ERROR | `param_count` access"
78 | f"for transforms: {transforms} failed; "
79 | f"hit error:\n{e}"
80 | )
81 |
82 | return param_count
83 |
84 |
85 | # --------------------------------------------------------------------------------
86 | # String representation manipulations.
87 | # --------------------------------------------------------------------------------
88 |
89 |
90 | def indent(data: str, indentor: str = " ", connector: str = "\n") -> str:
91 | """Indents the given string data with given `indentor`.
92 |
93 | :param data: The data to indent.
94 | :param indentor: The indenting character.
95 | DEFAULT: `" "` (two spaces).
96 | :param connector: The string used to concatenate given components.
97 | DEFAULT: `"\n"`.
98 |
99 | :return: The indented data.
100 |
101 | Careful not to have new-line characters in the `indentor`. This is
102 | currently allowed but NOT tested, use at your own risk.
103 | """
104 | lines = string_to_components(string=data, separator=connector)
105 | indented_lines = indent_lines(lines=lines, indentor=indentor)
106 | indented_data = components_to_string(components=indented_lines, connector=connector)
107 |
108 | return indented_data
109 |
110 |
111 | def indent_lines(lines: t.List[str], indentor: str = " ") -> t.List[str]:
112 | """Adds as prefix the `indentor` string to each of the given lines.
113 |
114 | :param lines: The lines of the content.
115 | :param indentor: The indenting character.
116 | DEFAULT: `" "` (two spaces).
117 |
118 | :return: The indented lines.
119 |
120 | Careful not to have new-line characters in the `indentor`. This is
121 | currently allowed but NOT tested, use at your own risk.
122 | """
123 |
124 | return [indent_line(line=line, indentor=indentor) for line in lines]
125 |
126 |
127 | def indent_line(line: str, indentor: str = " ") -> str:
128 | """Indents the given `string` using the given `indentor`.
129 |
130 | :param line: The string to be indented.
131 | :param indentor: The indenting string.
132 |
133 | :returns: The indented string.
134 | """
135 |
136 | return f"{indentor}{line}"
137 |
138 |
139 | def string_to_components(string: str, separator: str = "\n") -> t.List[str]:
140 | """Split a given string into components based on the given `separator`.
141 |
142 | :param string: The input string.
143 | :param separator: The string to split the given string into components.
144 |
145 | :returns: The components extracted from this string.
146 | """
147 |
148 | return string.split(separator)
149 |
150 |
151 | def components_to_string(components: t.List[str], connector: str = ",\n") -> str:
152 | """Concatenate given components into a string using given `connector`.
153 |
154 | :param components: The components of a string.
155 | :param connector: The string used to concatenate given components.
156 |
157 | :returns: The concatenated string made from the given `components` using
158 | the given `connector` string.
159 | """
160 | return connector.join(components)
161 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Apple Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # Metadata
17 | # ============================================================================
18 | [project]
19 |
20 | name = "parameterized_transforms"
21 | version = "1.0.0"
22 | description = "Transforms that provide reproducible access to parameterization information."
23 | authors = [
24 | {name="Eeshan Gunesh Dhekane", email="eeshangunesh_dhekane@apple.com"}
25 | ]
26 | license= {"text"="Apache License, Version 2.0"}
27 | readme = "README.md"
28 | requires-python = ">=3.9.6" # Improve this.
29 | dependencies = [
30 | "numpy",
31 | "torch",
32 | "torchvision"
33 | ]
34 |
35 |
36 | [tool.setuptools]
37 | py-modules = ["parameterized_transforms"]
38 |
39 |
40 | [project.urls]
41 | Repository = "https://github.com/apple/parameterized-transforms"
42 |
43 |
44 | [project.optional-dependencies]
45 | test = [
46 | "pytest",
47 | "pytest-xdist",
48 | "pytest-cov",
49 | "pytest-memray",
50 | "coverage[toml]"
51 | ]
52 |
53 |
54 | # Pytest options
55 | # ============================================================================
56 | [tool.pytest.ini_options]
57 | # This determines where tests are found
58 | testpaths = ["tests/"]
59 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==2.0.2
2 | torch==2.6.0
3 | torchvision==0.21.0
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/tests/test_atomic_transforms/__init__.py
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_CenterCrop.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, to_size, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | (
45 | 32,
46 | 32,
47 | ),
48 | (
49 | 28,
50 | 28,
51 | ),
52 | ],
53 | [
54 | [224, 224],
55 | [32, 32],
56 | [28, 28],
57 | (224, 224),
58 | (32, 32),
59 | (28, 28),
60 | (224, 32),
61 | [32, 28],
62 | (224,),
63 | (32,),
64 | (28,),
65 | [
66 | 224,
67 | ],
68 | [
69 | 32,
70 | ],
71 | [
72 | 28,
73 | ],
74 | 224,
75 | 32,
76 | 28,
77 | ],
78 | [
79 | (None,),
80 | (3,),
81 | (4,),
82 | ],
83 | ["CASCADE", "CONSUME"],
84 | )
85 | ),
86 | )
87 | def test_tx_on_PIL_images(
88 | size: t.Tuple[int],
89 | to_size: t.Union[t.Tuple[int], t.List[int], int],
90 | channels: t.Tuple[int],
91 | tx_mode: str,
92 | ) -> None:
93 |
94 | if channels == (None,):
95 | channels = ()
96 |
97 | size = size + channels
98 |
99 | tx = ptx.CenterCrop(size=to_size, tx_mode=tx_mode)
100 |
101 | img = Image.fromarray(
102 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
103 | )
104 |
105 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
106 |
107 | aug_1, params_1 = tx(img, orig_params)
108 |
109 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
110 |
111 | aug_3, params_3 = tx.consume_transform(img, orig_params)
112 |
113 | assert orig_params == params_1
114 | assert orig_params == params_2
115 | assert orig_params == params_3
116 |
117 | assert not ImageChops.difference(aug_1, aug_2).getbbox()
118 | assert not ImageChops.difference(aug_1, aug_3).getbbox()
119 |
120 | id_params = tx.get_default_params(img=img, processed=True)
121 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
122 |
123 | if img.size == id_aug_img.size:
124 | assert id_aug_rem_params == ()
125 | assert not ImageChops.difference(id_aug_img, img).getbbox()
126 |
127 |
128 | @pytest.mark.parametrize(
129 | "size, to_size, channels, tx_mode",
130 | list(
131 | product(
132 | [
133 | (
134 | 224,
135 | 224,
136 | ),
137 | (
138 | 32,
139 | 32,
140 | ),
141 | (
142 | 28,
143 | 28,
144 | ),
145 | ],
146 | [
147 | [224, 224],
148 | [32, 32],
149 | [28, 28],
150 | (224, 224),
151 | (32, 32),
152 | (28, 28),
153 | (224, 32),
154 | [32, 28],
155 | (224,),
156 | (32,),
157 | (28,),
158 | [
159 | 224,
160 | ],
161 | [
162 | 32,
163 | ],
164 | [
165 | 28,
166 | ],
167 | 224,
168 | 32,
169 | 28,
170 | ],
171 | [
172 | (None,),
173 | (1,),
174 | (3,),
175 | (4,),
176 | ],
177 | ["CASCADE", "CONSUME"],
178 | )
179 | ),
180 | )
181 | def test_tx_on_torch_tensors(
182 | size: t.Tuple[int],
183 | to_size: t.Union[t.Tuple[int], t.List[int], int],
184 | channels: t.Tuple[int],
185 | tx_mode: str,
186 | ) -> None:
187 |
188 | if channels == (None,):
189 | channels = ()
190 |
191 | size = channels + size
192 |
193 | tx = ptx.CenterCrop(size=to_size, tx_mode=tx_mode)
194 |
195 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
196 |
197 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
198 |
199 | aug_1, params_1 = tx(img, orig_params)
200 |
201 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
202 |
203 | aug_3, params_3 = tx.consume_transform(img, orig_params)
204 |
205 | assert orig_params == params_1
206 | assert orig_params == params_2
207 | assert orig_params == params_3
208 |
209 | assert torch.all(torch.eq(aug_1, aug_2))
210 | assert torch.all(torch.eq(aug_1, aug_3))
211 |
212 | id_params = tx.get_default_params(img=img, processed=True)
213 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
214 |
215 | if img.shape == id_aug_img.shape:
216 | assert id_aug_rem_params == ()
217 | assert torch.all(torch.eq(id_aug_img, img))
218 |
219 |
220 | # Main.
221 | if __name__ == "__main__":
222 | pass
223 |
--------------------------------------------------------------------------------
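
A minimal sketch of the round-trip contract that the CenterCrop tests above exercise, assuming the same parameterized_transforms API; CenterCrop is deterministic, so it adds no parameters of its own and cascade/consume yield the same crop:

    import numpy as np
    import PIL.Image as Image
    import PIL.ImageChops as ImageChops
    import parameterized_transforms.transforms as ptx

    tx = ptx.CenterCrop(size=32, tx_mode="CASCADE")
    img = Image.fromarray(np.random.randint(0, 256, size=(224, 224, 3)).astype(np.uint8))

    aug_a, params_a = tx.cascade_transform(img, (3.0, 4.0))
    aug_b, params_b = tx.consume_transform(img, (3.0, 4.0))

    assert params_a == params_b == (3.0, 4.0)                 # nothing added, nothing consumed
    assert not ImageChops.difference(aug_a, aug_b).getbbox()  # identical 32x32 center crop
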
/tests/test_atomic_transforms/test_ColorJitter.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 | import parameterized_transforms.core as ptc
25 |
26 | import numpy as np
27 |
28 | import PIL.Image as Image
29 | import PIL.ImageChops as ImageChops
30 |
31 | from itertools import product
32 |
33 | import typing as t
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "size, brightness, contrast, saturation, hue, channels, tx_mode, default_params_mode",
38 | list(
39 | product(
40 | [
41 | (
42 | 224,
43 | 224,
44 | ),
45 | ],
46 | [
47 | 0.5,
48 | [0.0, 3.0],
49 | ],
50 | [
51 | 0.5,
52 | [0.0, 3.0],
53 | ],
54 | [
55 | 0.5,
56 | [0.0, 3.0],
57 | ],
58 | [
59 | 0.247,
60 | [0.1, 0.3],
61 | ],
62 | [(3,), (4,), (None,)],
63 | ["CASCADE", "CONSUME"],
64 | [
65 | ptc.DefaultParamsMode.UNIQUE,
66 | ptc.DefaultParamsMode.RANDOMIZED,
67 | ],
68 | )
69 | ),
70 | )
71 | def test_tx_on_PIL_images(
72 | size: t.Tuple[int],
73 | brightness: t.Union[float, t.List[float], t.Tuple[float, float]],
74 | contrast: t.Union[float, t.List[float], t.Tuple[float, float]],
75 | saturation: t.Union[float, t.List[float], t.Tuple[float, float]],
76 | hue: t.Union[float, t.List[float], t.Tuple[float, float]],
77 | channels: t.Tuple[int],
78 | tx_mode: str,
79 | default_params_mode: ptc.DefaultParamsMode,
80 | ) -> None:
81 | if channels == (None,):
82 | channels = ()
83 |
84 | size = size + channels
85 |
86 | tx = ptx.ColorJitter(
87 | brightness=brightness,
88 | contrast=contrast,
89 | saturation=saturation,
90 | hue=hue,
91 | tx_mode=tx_mode,
92 | default_params_mode=default_params_mode,
93 | )
94 |
95 | img = Image.fromarray(
96 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
97 | )
98 |
99 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
100 |
101 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
102 | assert len(params_1) - len(orig_params) == tx.param_count
103 |
104 | orig_params = ()
105 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
106 | aug_3, params_3 = tx.consume_transform(img, params_2)
107 |
108 | assert orig_params == params_3
109 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
110 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
111 |
112 | id_params = tx.get_default_params(img=img, processed=True)
113 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
114 | assert id_aug_rem_params == ()
115 | assert not ImageChops.difference(id_aug_img, img).getbbox()
116 |
117 |
118 | @pytest.mark.parametrize(
119 | "size, brightness, contrast, saturation, hue, channels, tx_mode, default_params_mode",
120 | list(
121 | product(
122 | [
123 | (
124 | 224,
125 | 224,
126 | ),
127 | ],
128 | [
129 | 0.5,
130 | [0.0, 3.0],
131 | ],
132 | [
133 | 0.5,
134 | [0.0, 3.0],
135 | ],
136 | [
137 | 0.5,
138 | [0.0, 3.0],
139 | ],
140 | [
141 | 0.247,
142 | [0.1, 0.3],
143 | ],
144 | [
145 | (1,),
146 | (3,),
147 | # (4,), # This creates issues
148 | # (None,), # This creates issues
149 | ],
150 | ["CASCADE", "CONSUME"],
151 | [
152 | ptc.DefaultParamsMode.UNIQUE,
153 | ptc.DefaultParamsMode.RANDOMIZED,
154 | ],
155 | )
156 | ),
157 | )
158 | def test_tx_on_torch_tensors(
159 | size: t.Tuple[int],
160 | brightness: t.Union[float, t.List[float], t.Tuple[float, float]],
161 | contrast: t.Union[float, t.List[float], t.Tuple[float, float]],
162 | saturation: t.Union[float, t.List[float], t.Tuple[float, float]],
163 | hue: t.Union[float, t.List[float], t.Tuple[float, float]],
164 | channels: t.Tuple[int],
165 | tx_mode: str,
166 | default_params_mode: ptc.DefaultParamsMode,
167 | ) -> None:
168 | if channels == (None,):
169 | channels = ()
170 |
171 | size = channels + size
172 |
173 | tx = ptx.ColorJitter(
174 | brightness=brightness,
175 | contrast=contrast,
176 | saturation=saturation,
177 | hue=hue,
178 | tx_mode=tx_mode,
179 | default_params_mode=default_params_mode,
180 | )
181 |
182 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
183 |
184 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
185 |
186 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
187 | assert len(params_1) - len(orig_params) == tx.param_count
188 |
189 | orig_params = ()
190 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
191 | aug_3, params_3 = tx.consume_transform(img, params_2)
192 |
193 | assert orig_params == params_3
194 | assert torch.all(torch.eq(aug_2, aug_3))
195 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
196 |
197 | id_params = tx.get_default_params(img=img, processed=True)
198 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
199 | assert id_aug_rem_params == ()
200 | assert torch.all(torch.eq(id_aug_img, img))
201 |
202 |
203 | # Main.
204 | if __name__ == "__main__":
205 | pass
206 |
--------------------------------------------------------------------------------
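
ColorJitter, unlike CenterCrop, samples random values, and the tests above pin down how those values surface. A minimal sketch of the parameter-accumulation contract, assuming the same API and the default tx_mode/default_params_mode:

    import numpy as np
    import PIL.Image as Image
    import PIL.ImageChops as ImageChops
    import parameterized_transforms.transforms as ptx

    tx = ptx.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
    img = Image.fromarray(np.random.randint(0, 256, size=(224, 224, 3)).astype(np.uint8))

    aug, params = tx.cascade_transform(img, ())
    assert len(params) == tx.param_count                      # sampled values are exposed

    replay, leftover = tx.consume_transform(img, params)
    assert leftover == ()                                     # every value was consumed
    assert not ImageChops.difference(aug, replay).getbbox()   # the exact jitter is replayed
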
/tests/test_atomic_transforms/test_ConvertImageDtype.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 | import torch
21 |
22 | import parameterized_transforms.transforms as ptx
23 |
24 | import numpy as np
25 |
26 | import PIL.Image as Image
27 |
28 | from itertools import product
29 |
30 | import typing as t
31 |
32 |
33 | @pytest.mark.parametrize(
34 | "dtype, mode",
35 | list(
36 | product(
37 | [
38 | torch.int,
39 | torch.int8,
40 | torch.int16,
41 | torch.int32,
42 | torch.int64,
43 | torch.float,
44 | torch.float16,
45 | torch.float32,
46 | torch.float64,
47 | ],
48 | ["CASCADE", "CONSUME"],
49 | )
50 | ),
51 | )
52 | def test_tx_on_PIL_images(dtype: t.Any, mode: str) -> None:
53 |
54 | with pytest.raises(TypeError):
55 |
56 | tx = ptx.ConvertImageDtype(dtype=dtype, tx_mode=mode)
57 |
58 | img = Image.fromarray(
59 | np.random.randint(low=0, high=256, size=[224, 224, 3]).astype(np.uint8)
60 | )
61 |
62 | prev_params = (1, 0.0, -2.0, 3, -4)
63 |
64 | aug_img, params = tx(img, prev_params)
65 |
66 |
67 | @pytest.mark.parametrize(
68 | "size, channels, channel_first, high, high_type, to_dtype, tx_mode",
69 | list(
70 | product(
71 | [
72 | (
73 | 224,
74 | 224,
75 | ),
76 | ],
77 | [
78 | (None,),
79 | (1,),
80 | (3,),
81 | (4,),
82 | ],
83 | [True, False],
84 | [1.0, 256],
85 | [
86 | np.float16,
87 | np.float32,
88 | np.float64,
89 | np.uint8,
90 | np.int8,
91 | np.int16,
92 | np.int32,
93 | np.int64,
94 | ],
95 | [
96 | torch.int,
97 | torch.int8,
98 | torch.int16,
99 | torch.int32,
100 | torch.int64,
101 | torch.float,
102 | torch.float16,
103 | torch.float32,
104 | torch.float64,
105 | ],
106 | ["CASCADE", "CONSUME"],
107 | )
108 | ),
109 | )
110 | def test_tx_on_ndarrays(
111 | size: t.Tuple[int],
112 | channels: t.Tuple[int],
113 | channel_first: bool,
114 | high: t.Union[float, int],
115 | high_type: t.Any,
116 | to_dtype: t.Any,
117 | tx_mode: str,
118 | ) -> None:
119 | if channels == (None,):
120 | channels = ()
121 |
122 | if channel_first:
123 | size = channels + size
124 | else:
125 | size = size + channels
126 |
127 | tx = ptx.ConvertImageDtype(dtype=to_dtype)
128 |
129 | if high == 1.0:
130 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
131 | else:
132 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8)
133 |
134 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
135 |
136 | with pytest.raises(TypeError):
137 | aug_1, params_1 = tx(img, orig_params)
138 |
139 |
140 | @pytest.mark.parametrize(
141 | "size, channels, channel_first, high, high_type, to_dtype, tx_mode",
142 | list(
143 | product(
144 | [
145 | (
146 | 224,
147 | 224,
148 | ),
149 | ],
150 | [
151 | (None,),
152 | (1,),
153 | (3,),
154 | (4,),
155 | ],
156 | [True, False],
157 | [1.0, 256],
158 | [
159 | np.float16,
160 | np.float32,
161 | np.float64,
162 | np.uint8,
163 | np.int8,
164 | np.int16,
165 | np.int32,
166 | np.int64,
167 | ],
168 | [
169 | torch.int,
170 | torch.int8,
171 | torch.int16,
172 | torch.int32,
173 | torch.int64,
174 | torch.float,
175 | torch.float16,
176 | torch.float32,
177 | torch.float64,
178 | ],
179 | ["CASCADE", "CONSUME"],
180 | )
181 | ),
182 | )
183 | def test_tx_on_torch_tensors(
184 | size: t.Tuple[int],
185 | channels: t.Tuple[int],
186 | channel_first: bool,
187 | high: t.Union[float, int],
188 | high_type: t.Any,
189 | to_dtype: t.Any,
190 | tx_mode: str,
191 | ) -> None:
192 | class DummyContext(object):
193 | def __enter__(self):
194 | pass
195 |
196 | def __exit__(self, exc_type, exc_val, exc_tb):
197 | pass
198 |
199 | if channels == (None,):
200 | channels = ()
201 |
202 | if channel_first:
203 | size = channels + size
204 | else:
205 | size = size + channels
206 |
207 | tx = ptx.ConvertImageDtype(dtype=to_dtype)
208 |
209 | if high == 1.0:
210 | img = torch.from_numpy(
211 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
212 | )
213 | else:
214 | img = torch.from_numpy(
215 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
216 | )
217 |
218 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
219 |
220 | # Check dtypes.
221 | img_dtype_name = img.dtype
222 | out_dtype_name = to_dtype
223 |
224 | if (img_dtype_name, to_dtype) in [
225 | (torch.float32, torch.int32),
226 | (torch.float32, torch.int64),
227 | (torch.float64, torch.int64),
228 | ]:
229 | context = pytest.raises(RuntimeError)
230 | else:
231 | context = DummyContext()
232 |
233 | with context:
234 |
235 | aug_1, params_1 = tx(img, orig_params)
236 |
237 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
238 |
239 | aug_3, params_3 = tx.consume_transform(img, orig_params)
240 |
241 | assert len(list(aug_1.shape)) == 3 or len(list(aug_1.shape)) == 2
242 |
243 | assert orig_params == params_1
244 | assert orig_params == params_2
245 | assert orig_params == params_3
246 |
247 | assert torch.all(torch.eq(aug_1, aug_2))
248 | assert torch.all(torch.eq(aug_1, aug_3))
249 |
250 | id_params = tx.get_default_params(img=img, processed=True)
251 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
252 | if id_aug_img.dtype == img.dtype:
253 | assert id_aug_rem_params == ()
254 | assert torch.all(torch.eq(id_aug_img, img))
255 |
256 |
257 | # Main.
258 | if __name__ == "__main__":
259 | pass
260 |
--------------------------------------------------------------------------------
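
The DummyContext/pytest.raises switch above mirrors torchvision's safety rule for convert_image_dtype: float32 to int32, float32 to int64, and float64 to int64 conversions are rejected with a RuntimeError because they cannot be performed without overflow. A minimal sketch of the guarded call, assuming the same ptx wrapper:

    import pytest
    import torch
    import parameterized_transforms.transforms as ptx

    tx = ptx.ConvertImageDtype(dtype=torch.int64)
    img = torch.rand(3, 224, 224)        # float32 image in [0, 1)

    with pytest.raises(RuntimeError):
        tx(img, ())                      # unsafe float32 -> int64 conversion
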
/tests/test_atomic_transforms/test_ElasticTransform.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | from itertools import product
19 |
20 | import numpy as np
21 |
22 | import parameterized_transforms.transforms as ptx
23 | import parameterized_transforms.core as ptc
24 |
25 | import pytest
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | import torch
31 | import torchvision.transforms.functional as tv_fn
32 |
33 | import typing as t
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "image_shape, alpha, sigma, interpolation, fill, tx_mode, channels",
38 | list(
39 | product(
40 | [
41 | (
42 | 224,
43 | 224,
44 | ),
45 | ], # image_shape
46 | [
47 | 0.0,
48 | 50.0,
49 | [25.0, 100.0],
50 | ], # alpha
51 | [
52 | 1.0,
53 | 0.5,
54 | 5.0,
55 | ], # sigma
56 | [
57 | tv_fn.InterpolationMode.NEAREST,
58 | tv_fn.InterpolationMode.BILINEAR,
59 | tv_fn.InterpolationMode.BICUBIC,
60 | # # The modes below give error in applying displacements.
61 | # # tv_fn.InterpolationMode.BOX,
62 | # # tv_fn.InterpolationMode.LANCZOS,
63 | # # tv_fn.InterpolationMode.HAMMING,
64 | ], # interpolation
65 | [0.0,], # fill
66 | ["CASCADE", "CONSUME"], # tx_mode
67 | [(3,), (4,), (None,)], # channels
68 | )
69 | ),
70 | )
71 | def test_tx_on_PIL_images(
72 | image_shape: t.Union[int, t.List[int], t.Tuple[int, int]],
73 | alpha: t.Union[float, t.List[float], t.Tuple[float, float]],
74 | sigma: t.Union[float, t.List[float], t.Tuple[float, float]],
75 | interpolation: tv_fn.InterpolationMode,
76 | fill: t.Union[float, t.Sequence[float]],
77 | tx_mode: ptc.TRANSFORM_MODE_TYPE,
78 | channels: t.Tuple[t.Optional[int]],
79 | ):
80 | if channels == (None,):
81 | channels = ()
82 |
83 | size = image_shape + channels
84 |
85 | tx = ptx.ElasticTransform(
86 | image_shape=image_shape,
87 | alpha=alpha,
88 | sigma=sigma,
89 | interpolation=interpolation,
90 | fill=fill,
91 | tx_mode=tx_mode,
92 | )
93 |
94 | img = Image.fromarray(
95 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
96 | )
97 |
98 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
99 |
100 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
101 | assert len(params_1) - len(orig_params) == tx.param_count
102 |
103 | orig_params = ()
104 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
105 | aug_3, params_3 = tx.consume_transform(img, params_2)
106 |
107 | assert orig_params == params_3
108 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
109 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
110 |
111 | id_params = tx.get_default_params(img=img, processed=True)
112 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
113 | assert id_aug_rem_params == ()
114 | assert not ImageChops.difference(id_aug_img, img).getbbox()
115 |
116 |
117 | if __name__ == "__main__":
118 | pass
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_FiveCrop.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, to_size, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | (
45 | 32,
46 | 32,
47 | ),
48 | ],
49 | [
50 | # 224, # This gives error.
51 | 28,
52 | 24,
53 | 3,
54 | 8,
55 | [24, 24],
56 | [3, 3],
57 | [8, 8],
58 | (24, 24),
59 | (3, 3),
60 | (8, 8),
61 | [24, 3],
62 | [3, 8],
63 | (8, 24),
64 | ],
65 | [(3,), (4,), (None,)],
66 | ["CASCADE", "CONSUME"],
67 | )
68 | ),
69 | )
70 | def test_tx_on_PIL_images(
71 | size: t.Tuple[int],
72 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]],
73 | channels: t.Tuple[int],
74 | tx_mode: str,
75 | ) -> None:
76 | if channels == (None,):
77 | channels = ()
78 |
79 | size = size + channels
80 |
81 | tx = ptx.FiveCrop(
82 | size=to_size,
83 | tx_mode=tx_mode,
84 | )
85 |
86 | img = Image.fromarray(
87 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
88 | )
89 |
90 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
91 |
92 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
93 | assert len(params_1) - len(orig_params) == tx.param_count
94 |
95 | orig_params = ()
96 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
97 | aug_3, params_3 = tx.consume_transform(img, params_2)
98 |
99 | assert orig_params == params_3
100 | assert all(
101 | [
102 | not ImageChops.difference(aug_2_component, aug_3_component).getbbox()
103 | for aug_2_component, aug_3_component in zip(aug_2, aug_3)
104 | ]
105 | )
106 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
107 |
108 | id_params = tx.get_default_params(img=img, processed=True)
109 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
110 | if img.size == id_aug_img[0].size:
111 | assert all(
112 | not ImageChops.difference(img, an_id_aug_img).getbbox()
113 | for an_id_aug_img in id_aug_img
114 | )
115 | assert id_aug_rem_params == ()
116 |
117 |
118 | @pytest.mark.parametrize(
119 | "size, to_size, channels, tx_mode",
120 | list(
121 | product(
122 | [
123 | (
124 | 224,
125 | 224,
126 | ),
127 | (
128 | 32,
129 | 32,
130 | ),
131 | ],
132 | [
133 | # 224, # This gives an error because the crop size can exceed the image size.
134 | 28,
135 | 24,
136 | 3,
137 | 8,
138 | [24, 24],
139 | [3, 3],
140 | [8, 8],
141 | (24, 24),
142 | (3, 3),
143 | (8, 8),
144 | [24, 3],
145 | [3, 8],
146 | (8, 24),
147 | ],
148 | [(3,), (4,), (None,)],
149 | ["CASCADE", "CONSUME"],
150 | )
151 | ),
152 | )
153 | def test_tx_on_torch_tensors(
154 | size: t.Tuple[int],
155 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]],
156 | channels: t.Tuple[int],
157 | tx_mode: str,
158 | ) -> None:
159 | if channels == (None,):
160 | channels = ()
161 |
162 | size = channels + size
163 |
164 | tx = ptx.FiveCrop(
165 | size=to_size,
166 | tx_mode=tx_mode,
167 | )
168 |
169 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
170 |
171 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
172 |
173 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
174 | assert len(params_1) - len(orig_params) == tx.param_count
175 |
176 | orig_params = ()
177 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
178 | aug_3, params_3 = tx.consume_transform(img, params_2)
179 |
180 | assert orig_params == params_3
181 | assert all(
182 | [
183 | torch.all(torch.eq(aug_2_component, aug_3_component))
184 | for aug_2_component, aug_3_component in zip(aug_2, aug_3)
185 | ]
186 | )
187 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
188 |
189 | id_params = tx.get_default_params(img=img, processed=True)
190 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
191 | if img.shape == id_aug_img[0].shape:
192 | assert all(
193 | torch.all(torch.eq(img, an_id_aug_img)) for an_id_aug_img in id_aug_img
194 | )
195 | assert id_aug_rem_params == ()
196 |
197 |
198 | # Main.
199 | if __name__ == "__main__":
200 | pass
201 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_GaussianBlur.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, kernel_size, sigma, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [
46 | 1,
47 | 3,
48 | # 4, # kernel size should be odd
49 | 5,
50 | 11,
51 | ],
52 | [
53 | 1,
54 | 0.1,
55 | 0.01,
56 | 10.0,
57 | ],
58 | [(3,), (4,), (None,)],
59 | ["CASCADE", "CONSUME"],
60 | )
61 | ),
62 | )
63 | def test_tx_on_PIL_images(
64 | size: t.Tuple[int],
65 | kernel_size: t.Union[int, t.List[int], t.Tuple[int, int]],
66 | sigma: t.Union[float, t.List[float], t.Tuple[float, float]],
67 | channels: t.Tuple[int],
68 | tx_mode: str,
69 | ) -> None:
70 |
71 | if channels == (None,):
72 | channels = ()
73 |
74 | size = size + channels
75 |
76 | tx = ptx.GaussianBlur(
77 | kernel_size=kernel_size,
78 | sigma=sigma,
79 | tx_mode=tx_mode,
80 | )
81 |
82 | img = Image.fromarray(
83 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
84 | )
85 |
86 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
87 |
88 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
89 | assert len(params_1) - len(orig_params) == tx.param_count
90 |
91 | orig_params = ()
92 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
93 | aug_3, params_3 = tx.consume_transform(img, params_2)
94 |
95 | assert orig_params == params_3
96 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
97 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
98 |
99 | id_params = tx.get_default_params(img=img, processed=True)
100 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
101 | assert id_aug_rem_params == ()
102 | assert not ImageChops.difference(id_aug_img, img).getbbox()
103 |
104 |
105 | @pytest.mark.parametrize(
106 | "size, kernel_size, sigma, channels, tx_mode",
107 | list(
108 | product(
109 | [
110 | (
111 | 224,
112 | 224,
113 | ),
114 | ],
115 | [
116 | 1,
117 | 3,
118 | # 4, # kernel size should be odd
119 | 5,
120 | 11,
121 | ],
122 | [
123 | 1,
124 | 0.1,
125 | 0.01,
126 | 10.0,
127 | ],
128 | [
129 | (3,),
130 | (4,),
131 | ],
132 | ["CASCADE", "CONSUME"],
133 | )
134 | ),
135 | )
136 | def test_tx_on_torch_tensors(
137 | size: t.Tuple[int],
138 | kernel_size: t.Union[int, t.List[int], t.Tuple[int, int]],
139 | sigma: t.Union[float, t.List[float], t.Tuple[float, float]],
140 | channels: t.Tuple[int],
141 | tx_mode: str,
142 | ) -> None:
143 |
144 | if channels == (None,):
145 | channels = ()
146 |
147 | size = channels + size
148 |
149 | tx = ptx.GaussianBlur(
150 | kernel_size=kernel_size,
151 | sigma=sigma,
152 | tx_mode=tx_mode,
153 | )
154 |
155 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
156 |
157 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
158 |
159 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
160 | assert len(params_1) - len(orig_params) == tx.param_count
161 |
162 | orig_params = ()
163 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
164 | aug_3, params_3 = tx.consume_transform(img, params_2)
165 |
166 | assert orig_params == params_3
167 | assert torch.all(torch.eq(aug_2, aug_3))
168 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
169 |
170 | id_params = tx.get_default_params(img=img, processed=True)
171 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
172 | assert id_aug_rem_params == ()
173 | assert torch.all(torch.eq(id_aug_img, img))
174 |
175 |
176 | # Main.
177 | if __name__ == "__main__":
178 | pass
179 |
--------------------------------------------------------------------------------
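
The last four lines of each test above check the identity contract of get_default_params: with processed=True it returns a parameter tuple that, once consumed, leaves the image unchanged and no parameters behind. A minimal sketch, assuming the same API and a default tx_mode:

    import torch
    import parameterized_transforms.transforms as ptx

    tx = ptx.GaussianBlur(kernel_size=3, sigma=1.0)
    img = torch.rand(3, 224, 224, dtype=torch.float64)

    id_params = tx.get_default_params(img=img, processed=True)
    same_img, leftover = tx.consume_transform(img=img, params=id_params)

    assert leftover == ()
    assert torch.all(torch.eq(same_img, img))   # default params act as an identity
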
/tests/test_atomic_transforms/test_Grayscale.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, to_channels, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [
46 | 1,
47 | 3,
48 | ],
49 | [(3,), (4,), (None,)],
50 | ["CASCADE", "CONSUME"],
51 | )
52 | ),
53 | )
54 | def test_tx_on_PIL_images(
55 | size: t.Tuple[int],
56 | to_channels: int,
57 | channels: t.Tuple[int],
58 | tx_mode: str,
59 | ) -> None:
60 |
61 | if channels == (None,):
62 | channels = ()
63 |
64 | size = size + channels
65 |
66 | tx = ptx.Grayscale(
67 | num_output_channels=to_channels,
68 | tx_mode=tx_mode,
69 | )
70 |
71 | img = Image.fromarray(
72 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
73 | )
74 |
75 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
76 |
77 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
78 | assert len(params_1) - len(orig_params) == tx.param_count
79 |
80 | orig_params = ()
81 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
82 | aug_3, params_3 = tx.consume_transform(img, params_2)
83 |
84 | assert orig_params == params_3
85 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
86 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
87 |
88 | # # This is a transform that always alters the image.
89 | # id_params = tx.get_default_params(img=img, processed=True)
90 | # id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
91 | # assert id_aug_rem_params == ()
92 | # assert not ImageChops.difference(id_aug_img, img).getbbox()
93 |
94 |
95 | @pytest.mark.parametrize(
96 | "size, to_channels, channels, tx_mode",
97 | list(
98 | product(
99 | [
100 | (
101 | 224,
102 | 224,
103 | ),
104 | ],
105 | [
106 | 1,
107 | 3,
108 | ],
109 | [(3,)],
110 | ["CASCADE", "CONSUME"],
111 | )
112 | ),
113 | )
114 | def test_tx_on_torch_tensors(
115 | size: t.Tuple[int],
116 | to_channels: int,
117 | channels: t.Tuple[int],
118 | tx_mode: str,
119 | ) -> None:
120 |
121 | if channels == (None,):
122 | channels = ()
123 |
124 | size = channels + size
125 |
126 | tx = ptx.Grayscale(
127 | num_output_channels=to_channels,
128 | tx_mode=tx_mode,
129 | )
130 |
131 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
132 |
133 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
134 |
135 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
136 | assert len(params_1) - len(orig_params) == tx.param_count
137 |
138 | orig_params = ()
139 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
140 | aug_3, params_3 = tx.consume_transform(img, params_2)
141 |
142 | assert orig_params == params_3
143 | assert torch.all(torch.eq(aug_2, aug_3))
144 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
145 |
146 | # # This is a transform that always alters the image.
147 | # id_params = tx.get_default_params(img=img, processed=True)
148 | # id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
149 | # assert id_aug_rem_params == ()
150 | # assert torch.all(torch.eq(id_aug_img, img))
151 |
152 |
153 | # Main.
154 | if __name__ == "__main__":
155 | pass
156 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_Lambda.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, lambd, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [
46 | lambda x: x,
47 | ],
48 | [(3,), (4,), (None,)],
49 | ["CASCADE", "CONSUME"],
50 | )
51 | ),
52 | )
53 | def test_tx_on_PIL_images(
54 | size: t.Tuple[int],
55 | lambd: t.Callable,
56 | channels: int,
57 | tx_mode: str,
58 | ) -> None:
59 |
60 | if channels == (None,):
61 | channels = ()
62 |
63 | size = size + channels
64 |
65 | tx = ptx.Lambda(lambd=lambd, tx_mode=tx_mode)
66 |
67 | img = Image.fromarray(
68 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
69 | )
70 |
71 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
72 |
73 | aug_1, params_1 = tx(img, orig_params)
74 |
75 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
76 |
77 | aug_3, params_3 = tx.consume_transform(img, orig_params)
78 |
79 | assert orig_params == params_1
80 | assert orig_params == params_2
81 | assert orig_params == params_3
82 |
83 | assert not ImageChops.difference(aug_1, aug_2).getbbox()
84 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
85 |
86 | # # This transform does NOT have identity params.
87 | # # With given identity function, we get the same image back.
88 | id_params = tx.get_default_params(img=img, processed=True)
89 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
90 | assert id_aug_rem_params == ()
91 | assert not ImageChops.difference(id_aug_img, img).getbbox()
92 |
93 |
94 | @pytest.mark.parametrize(
95 | "size, lambd, channels, tx_mode",
96 | list(
97 | product(
98 | [
99 | (
100 | 224,
101 | 224,
102 | ),
103 | ],
104 | [
105 | lambda x: x,
106 | ],
107 | [(3,), (4,), (None,)],
108 | ["CASCADE", "CONSUME"],
109 | )
110 | ),
111 | )
112 | def test_tx_on_torch_tensors(
113 | size: t.Tuple[int],
114 | lambd: t.Callable,
115 | channels: int,
116 | tx_mode: str,
117 | ) -> None:
118 |
119 | if channels == (None,):
120 | channels = ()
121 |
122 | size = channels + size
123 |
124 | tx = ptx.Lambda(lambd=lambd, tx_mode=tx_mode)
125 |
126 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
127 |
128 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
129 |
130 | aug_1, params_1 = tx(img, orig_params)
131 |
132 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
133 |
134 | aug_3, params_3 = tx.consume_transform(img, orig_params)
135 |
136 | assert orig_params == params_1
137 | assert orig_params == params_2
138 | assert orig_params == params_3
139 |
140 | assert torch.all(torch.eq(aug_1, aug_2))
141 | assert torch.all(torch.eq(aug_2, aug_3))
142 |
143 | # # This transform does NOT have identity params.
144 | # # With given identity function, we get the same image back.
145 | id_params = tx.get_default_params(img=img, processed=True)
146 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
147 | assert id_aug_rem_params == ()
148 | assert torch.all(torch.eq(id_aug_img, img))
149 |
150 |
151 | # Main.
152 | if __name__ == "__main__":
153 | pass
154 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_LinearTransformation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | from itertools import product
28 |
29 | import typing as t
30 |
31 |
32 | @pytest.mark.parametrize(
33 | "size, channels, tx_mode, repeats",
34 | list(
35 | product(
36 | [
37 | (
38 | 11,
39 | 11,
40 | ),
41 | (
42 | 12,
43 | 12,
44 | ),
45 | (
46 | 13,
47 | 13,
48 | ),
49 | ],
50 | [
51 | (3,),
52 | (4,),
53 | # (None,) # This gives error.
54 | ],
55 | ["CASCADE", "CONSUME"],
56 | [
57 | 100,
58 | ],
59 | )
60 | ),
61 | )
62 | def test_tx_on_torch_tensors(
63 | size: t.Tuple[int],
64 | channels: t.Tuple[int],
65 | tx_mode: str,
66 | repeats: int,
67 | ) -> None:
68 | if channels == (None,):
69 | channels = ()
70 |
71 | size = channels + size
72 |
73 | dim = np.prod(size)
74 |
75 | for _ in range(repeats):
76 |
77 | tx_matrix = torch.Tensor(dim, dim).uniform_().double()
78 | tx_mean = (
79 | torch.Tensor(
80 | dim,
81 | )
82 | .uniform_()
83 | .double()
84 | )
85 |
86 | tx = ptx.LinearTransformation(
87 | transformation_matrix=tx_matrix,
88 | mean_vector=tx_mean,
89 | tx_mode=tx_mode,
90 | )
91 |
92 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
93 |
94 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
95 |
96 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
97 | assert len(params_1) - len(orig_params) == tx.param_count
98 |
99 | orig_params = ()
100 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
101 | aug_3, params_3 = tx.consume_transform(img, params_2)
102 |
103 | assert orig_params == params_3
104 | assert torch.all(torch.eq(aug_2, aug_3))
105 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
106 |
107 | # # The identity params are NOT produced correctly in code for this transform, so the check below is skipped.
108 | # id_params = tx.get_default_params(img=img, processed=True)
109 | # id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
110 | # assert id_aug_rem_params == ()
111 | # assert torch.all(torch.eq(id_aug_img, img))
112 |
113 |
114 | # Main.
115 | if __name__ == "__main__":
116 | pass
117 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_PILToTensor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Dependencies.
18 | import pytest
19 | import torch
20 |
21 | import parameterized_transforms.transforms as ptx
22 |
23 | import numpy as np
24 |
25 | import PIL.Image as Image
26 |
27 | from itertools import product
28 |
29 | import typing as t
30 |
31 |
32 | @pytest.mark.parametrize(
33 | "size, channels, tx_mode",
34 | list(
35 | product(
36 | [
37 | (
38 | 224,
39 | 224,
40 | ),
41 | ],
42 | [
43 | (None,),
44 | (3,),
45 | (4,),
46 | ],
47 | ["CASCADE", "CONSUME"],
48 | )
49 | ),
50 | )
51 | def test_tx_on_PIL_images_1(
52 | size: t.Tuple[int],
53 | channels: t.Tuple[int],
54 | tx_mode: str,
55 | ) -> None:
56 | if channels == (None,):
57 | channels = ()
58 |
59 | size = size + channels
60 |
61 | tx = ptx.PILToTensor()
62 |
63 | img = Image.fromarray(
64 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
65 | )
66 |
67 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
68 |
69 | aug_1, params_1 = tx(img, orig_params)
70 |
71 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
72 |
73 | aug_3, params_3 = tx.consume_transform(img, orig_params)
74 |
75 | assert len(list(aug_1.shape)) == 3 # [C, H, W] format ensured
76 | assert aug_1.shape[0] == 1 or aug_1.shape[0] == 3 or aug_1.shape[0] == 4
77 | assert 0 <= torch.min(aug_1).item() <= 255.0
78 | assert 0 <= torch.max(aug_1).item() <= 255.0
79 |
80 | assert orig_params == params_1
81 | assert orig_params == params_2
82 | assert orig_params == params_3
83 |
84 | assert torch.all(torch.eq(aug_1, aug_2))
85 | assert torch.all(torch.eq(aug_1, aug_3))
86 |
87 | # # This transform does NOT have identity params.
88 | # id_params = tx.get_default_params(img=img, processed=True)
89 | # id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
90 | # assert id_aug_rem_params == ()
91 | # assert not ImageChops.difference(id_aug_img, img).getbbox()
92 |
93 |
94 | @pytest.mark.parametrize(
95 | "size, channels, tx_mode",
96 | list(
97 | product(
98 | [
99 | (
100 | 224,
101 | 224,
102 | ),
103 | ],
104 | [
105 | (1,),
106 | ],
107 | ["CASCADE", "CONSUME"],
108 | )
109 | ),
110 | )
111 | def test_tx_on_PIL_images_2(
112 | size: t.Tuple[int],
113 | channels: t.Tuple[int],
114 | tx_mode: str,
115 | ) -> None:
116 | with pytest.raises(TypeError):
117 | if channels == (None,):
118 | channels = ()
119 |
120 | size = size + channels
121 |
122 | tx = ptx.PILToTensor()
123 |
124 | img = Image.fromarray(
125 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
126 | )
127 |
128 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
129 |
130 | aug_1, params_1 = tx(img, orig_params)
131 |
132 |
133 | @pytest.mark.parametrize(
134 | "size, channels, channel_first, high, high_type, tx_mode",
135 | list(
136 | product(
137 | [
138 | (
139 | 224,
140 | 224,
141 | ),
142 | ],
143 | [
144 | (None,),
145 | (1,),
146 | (3,),
147 | (4,),
148 | ],
149 | [True, False],
150 | [1.0, 256],
151 | [
152 | np.float16,
153 | np.float32,
154 | np.float64,
155 | np.uint8,
156 | np.int8,
157 | np.int16,
158 | np.int32,
159 | np.int64,
160 | ],
161 | ["CASCADE", "CONSUME"],
162 | )
163 | ),
164 | )
165 | def test_tx_on_ndarrays(
166 | size: t.Tuple[int],
167 | channels: t.Tuple[int],
168 | channel_first: bool,
169 | high: t.Union[float, int],
170 | high_type: t.Any,
171 | tx_mode: str,
172 | ) -> None:
173 | if channels == (None,):
174 | channels = ()
175 |
176 | if channel_first:
177 | size = channels + size
178 | else:
179 | size = size + channels
180 |
181 | tx = ptx.PILToTensor()
182 |
183 | if high == 1.0:
184 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
185 | else:
186 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8)
187 |
188 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
189 |
190 | with pytest.raises(TypeError):
191 |
192 | aug_1, params_1 = tx(img, orig_params)
193 |
194 |
195 | @pytest.mark.parametrize(
196 | "size, channels, channel_first, high, high_type, tx_mode",
197 | list(
198 | product(
199 | [
200 | (
201 | 224,
202 | 224,
203 | ),
204 | ],
205 | [
206 | (None,),
207 | (1,),
208 | (3,),
209 | (4,),
210 | ],
211 | [True, False],
212 | [1.0, 256],
213 | [
214 | np.float16,
215 | np.float32,
216 | np.float64,
217 | np.uint8,
218 | np.int8,
219 | np.int16,
220 | np.int32,
221 | np.int64,
222 | ],
223 | ["CASCADE", "CONSUME"],
224 | )
225 | ),
226 | )
227 | def test_tx_on_torch_tensors(
228 | size: t.Tuple[int],
229 | channels: t.Tuple[int],
230 | channel_first: bool,
231 | high: t.Union[float, int],
232 | high_type: t.Any,
233 | tx_mode: str,
234 | ) -> None:
235 | if channels == (None,):
236 | channels = ()
237 |
238 | if channel_first:
239 | size = channels + size
240 | else:
241 | size = size + channels
242 |
243 | tx = ptx.PILToTensor()
244 |
245 | if high == 1.0:
246 | img = torch.from_numpy(
247 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
248 | )
249 | else:
250 | img = torch.from_numpy(
251 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
252 | )
253 |
254 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
255 |
256 | with pytest.raises(TypeError):
257 |
258 | aug_1, params_1 = tx(img, orig_params)
259 |
260 |
261 | # Main.
262 | if __name__ == "__main__":
263 | pass
264 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomAdjustSharpness.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, sharpness_factor, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.1, 0.5, 1.0, 2.0],
46 | [0.0, 0.5, 1.0],
47 | [
48 | (3,),
49 | (4,),
50 | (None,),
51 | ],
52 | ["CASCADE", "CONSUME"],
53 | )
54 | ),
55 | )
56 | def test_tx_on_PIL_images(
57 | size: t.Tuple[int],
58 | sharpness_factor: float,
59 | p: float,
60 | channels: t.Tuple[int],
61 | tx_mode: str,
62 | ) -> None:
63 | if channels == (None,):
64 | channels = ()
65 |
66 | size = size + channels
67 |
68 | tx = ptx.RandomAdjustSharpness(
69 | sharpness_factor=sharpness_factor,
70 | p=p,
71 | tx_mode=tx_mode,
72 | )
73 |
74 | img = Image.fromarray(
75 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
76 | )
77 |
78 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
79 |
80 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
81 | assert len(params_1) - len(orig_params) == tx.param_count
82 |
83 | orig_params = ()
84 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
85 | aug_3, params_3 = tx.consume_transform(img, params_2)
86 |
87 | assert orig_params == params_3
88 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
89 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
90 |
91 | id_params = tx.get_default_params(img=img, processed=True)
92 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
93 | assert id_aug_rem_params == ()
94 | assert not ImageChops.difference(id_aug_img, img).getbbox()
95 |
96 |
97 | @pytest.mark.parametrize(
98 | "size, sharpness_factor, p, channels, tx_mode",
99 | list(
100 | product(
101 | [
102 | (
103 | 224,
104 | 224,
105 | ),
106 | ],
107 | [0.1, 0.5, 1.0, 2.0],
108 | [0.0, 0.5, 1.0],
109 | [
110 | (3,),
111 | # # These give error
112 | # (4,),
113 | # (None,),
114 | ],
115 | ["CASCADE", "CONSUME"],
116 | )
117 | ),
118 | )
119 | def test_tx_on_torch_tensors(
120 | size: t.Tuple[int],
121 | sharpness_factor: float,
122 | p: float,
123 | channels: t.Tuple[int],
124 | tx_mode: str,
125 | ) -> None:
126 | if channels == (None,):
127 | channels = ()
128 |
129 | size = channels + size
130 |
131 | tx = ptx.RandomAdjustSharpness(
132 | sharpness_factor=sharpness_factor,
133 | p=p,
134 | tx_mode=tx_mode,
135 | )
136 |
137 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
138 |
139 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
140 |
141 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
142 | assert len(params_1) - len(orig_params) == tx.param_count
143 |
144 | orig_params = ()
145 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
146 | aug_3, params_3 = tx.consume_transform(img, params_2)
147 |
148 | assert orig_params == params_3
149 | assert torch.all(torch.eq(aug_2, aug_3))
150 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
151 |
152 | id_params = tx.get_default_params(img=img, processed=True)
153 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
154 | assert id_aug_rem_params == ()
155 | assert torch.all(torch.eq(id_aug_img, img))
156 |
157 |
158 | # Main.
159 | if __name__ == "__main__":
160 | pass
161 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomAutocontrast.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.0, 0.5, 1.0],
46 | [
47 | (3,),
48 | # # These are not allowed
49 | # (4,),
50 | # (None,),
51 | ],
52 | ["CASCADE", "CONSUME"],
53 | )
54 | ),
55 | )
56 | def test_tx_on_PIL_images(
57 | size: t.Tuple[int],
58 | p: float,
59 | channels: t.Tuple[int],
60 | tx_mode: str,
61 | ) -> None:
62 | if channels == (None,):
63 | channels = ()
64 |
65 | size = size + channels
66 |
67 | tx = ptx.RandomAutocontrast(
68 | p=p,
69 | tx_mode=tx_mode,
70 | )
71 |
72 | img = Image.fromarray(
73 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
74 | )
75 |
76 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
77 |
78 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
79 | assert len(params_1) - len(orig_params) == tx.param_count
80 |
81 | orig_params = ()
82 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
83 | aug_3, params_3 = tx.consume_transform(img, params_2)
84 |
85 | assert orig_params == params_3
86 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
87 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
88 |
89 | id_params = tx.get_default_params(img=img, processed=True)
90 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
91 | assert id_aug_rem_params == ()
92 | assert not ImageChops.difference(id_aug_img, img).getbbox()
93 |
94 |
95 | @pytest.mark.parametrize(
96 | "size, p, channels, tx_mode",
97 | list(
98 | product(
99 | [
100 | (
101 | 224,
102 | 224,
103 | ),
104 | ],
105 | [0.0, 0.5, 1.0],
106 | [
107 | (3,),
108 | # # These are not allowed
109 | # (4,),
110 | # (None,),
111 | ],
112 | ["CASCADE", "CONSUME"],
113 | )
114 | ),
115 | )
116 | def test_tx_on_torch_tensors(
117 | size: t.Tuple[int],
118 | p: float,
119 | channels: t.Tuple[int],
120 | tx_mode: str,
121 | ) -> None:
122 | if channels == (None,):
123 | channels = ()
124 |
125 | size = channels + size
126 |
127 | tx = ptx.RandomAutocontrast(
128 | p=p,
129 | tx_mode=tx_mode,
130 | )
131 |
132 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
133 |
134 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
135 |
136 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
137 | assert len(params_1) - len(orig_params) == tx.param_count
138 |
139 | orig_params = ()
140 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
141 | aug_3, params_3 = tx.consume_transform(img, params_2)
142 |
143 | assert orig_params == params_3
144 | assert torch.all(torch.eq(aug_2, aug_3))
145 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
146 |
147 | id_params = tx.get_default_params(img=img, processed=True)
148 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
149 | assert id_aug_rem_params == ()
150 | assert torch.all(torch.eq(id_aug_img, img))
151 |
152 |
153 | # Main.
154 | if __name__ == "__main__":
155 | pass
156 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomEqualize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.0, 0.5, 1.0],
46 | [
47 | (3,),
48 | # # These give error
49 | # (4,),
50 | # (None,),
51 | ],
52 | ["CASCADE", "CONSUME"],
53 | )
54 | ),
55 | )
56 | def test_tx_on_PIL_images(
57 | size: t.Tuple[int],
58 | p: float,
59 | channels: t.Tuple[int],
60 | tx_mode: str,
61 | ) -> None:
62 | if channels == (None,):
63 | channels = ()
64 |
65 | size = size + channels
66 |
67 | tx = ptx.RandomEqualize(
68 | p=p,
69 | tx_mode=tx_mode,
70 | )
71 |
72 | img = Image.fromarray(
73 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
74 | )
75 |
76 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
77 |
78 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
79 | assert len(params_1) - len(orig_params) == tx.param_count
80 |
81 | orig_params = ()
82 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
83 | aug_3, params_3 = tx.consume_transform(img, params_2)
84 |
85 | assert orig_params == params_3
86 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
87 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
88 |
89 | id_params = tx.get_default_params(img=img, processed=True)
90 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
91 | assert id_aug_rem_params == ()
92 | assert not ImageChops.difference(id_aug_img, img).getbbox()
93 |
94 |
95 | @pytest.mark.parametrize(
96 | "size, p, channels, tx_mode",
97 | list(
98 | product(
99 | [
100 | (
101 | 224,
102 | 224,
103 | ),
104 | ],
105 | [0.0, 0.5, 1.0],
106 | [
107 | (3,),
108 |                 # # These raise errors.
109 | # (4,),
110 | # (None,),
111 | ],
112 | ["CASCADE", "CONSUME"],
113 | )
114 | ),
115 | )
116 | def test_tx_on_torch_tensors(
117 | size: t.Tuple[int],
118 | p: float,
119 | channels: t.Tuple[int],
120 | tx_mode: str,
121 | ) -> None:
122 | if channels == (None,):
123 | channels = ()
124 |
125 | size = channels + size
126 |
127 | tx = ptx.RandomEqualize(
128 | p=p,
129 | tx_mode=tx_mode,
130 | )
131 |
132 | img = torch.from_numpy(
133 |         np.random.randint(low=0, high=256, size=size).astype(np.uint8)
134 | )
135 |
136 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
137 |
138 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
139 | assert len(params_1) - len(orig_params) == tx.param_count
140 |
141 | orig_params = ()
142 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
143 | aug_3, params_3 = tx.consume_transform(img, params_2)
144 |
145 | assert orig_params == params_3
146 | assert torch.all(torch.eq(aug_2, aug_3))
147 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
148 |
149 | id_params = tx.get_default_params(img=img, processed=True)
150 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
151 | assert id_aug_rem_params == ()
152 | assert torch.all(torch.eq(id_aug_img, img))
153 |
154 |
155 | # Main.
156 | if __name__ == "__main__":
157 | pass
158 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomErasing.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 | import parameterized_transforms.core as ptc
25 |
26 | import numpy as np
27 |
28 | from itertools import product
29 |
30 | import typing as t
31 |
32 |
33 | @pytest.mark.parametrize(
34 | "size, p, scale, ratio, value, default_params_mode, channels, tx_mode",
35 | list(
36 | product(
37 | [
38 | (
39 | 224,
40 | 224,
41 | ),
42 | ],
43 | # p
44 | [0.0, 0.5, 1.0],
45 | # scale
46 | [
47 | (0.02, 0.33),
48 | ],
49 | # ratio
50 | [
51 | (3.0 / 4, 4.0 / 3),
52 | ],
53 | # value
54 | ["random", 0, 23, 255.0],
55 | # identity params mode
56 | [
57 | ptc.DefaultParamsMode.RANDOMIZED,
58 | ptc.DefaultParamsMode.UNIQUE,
59 | ],
60 | # channels
61 | [
62 | (3,),
63 | (4,),
64 | ],
65 | ["CASCADE", "CONSUME"],
66 | )
67 | ),
68 | )
69 | def test_tx_on_torch_tensors(
70 | size: t.Tuple[int],
71 | p: float,
72 | scale: t.Union[t.Tuple[float, float], t.List[float]],
73 | ratio: t.Union[t.Tuple[float, float], t.List[float]],
74 |     value: t.Union[str, int, float, t.List[int], t.Tuple[int, int, int]],
75 | default_params_mode: ptc.DefaultParamsMode,
76 | channels: t.Tuple[int],
77 | tx_mode: str,
78 | ) -> None:
79 |
81 |     # PIL support is NOT guaranteed in torchvision, so only tensor inputs are tested.
81 |
82 | if channels == (None,):
83 | channels = ()
84 |
85 | size = channels + size
86 |
87 | tx = ptx.RandomErasing(
88 | p=p,
89 | scale=scale,
90 | ratio=ratio,
91 | value=value,
92 | tx_mode=tx_mode,
93 | default_params_mode=default_params_mode,
94 | )
95 |
96 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
97 |
98 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
99 |
100 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
101 | assert len(params_1) - len(orig_params) == tx.param_count
102 |
103 | orig_params = ()
104 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
105 | aug_3, params_3 = tx.consume_transform(img, params_2)
106 |
107 | assert orig_params == params_3
108 |
109 | if value != "random":
110 | assert torch.all(torch.eq(aug_2, aug_3))
111 |
112 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
113 |
114 | id_params = tx.get_default_params(img=img, processed=True)
115 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
116 | if value != "random":
117 | assert id_aug_rem_params == ()
118 | assert torch.all(torch.eq(id_aug_img, img))
119 |
120 |
121 | # Main.
122 | if __name__ == "__main__":
123 | pass
124 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomGrayscale.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.0, 0.5, 1.0],
46 | [
47 | (3,),
48 | # (4,),
49 | (None,),
50 | ],
51 | ["CASCADE", "CONSUME"],
52 | )
53 | ),
54 | )
55 | def test_tx_on_PIL_images(
56 | size: t.Tuple[int],
57 | p: float,
58 | channels: t.Tuple[int],
59 | tx_mode: str,
60 | ) -> None:
61 |
62 | if channels == (None,):
63 | channels = ()
64 |
65 | size = size + channels
66 |
67 | tx = ptx.RandomGrayscale(
68 | p=p,
69 | tx_mode=tx_mode,
70 | )
71 |
72 | img = Image.fromarray(
73 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
74 | )
75 |
76 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
77 |
78 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
79 | assert len(params_1) - len(orig_params) == tx.param_count
80 |
81 | orig_params = ()
82 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
83 | aug_3, params_3 = tx.consume_transform(img, params_2)
84 |
85 | assert orig_params == params_3
86 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
87 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
88 |
89 | id_params = tx.get_default_params(img=img, processed=True)
90 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
91 | assert id_aug_rem_params == ()
92 | assert not ImageChops.difference(id_aug_img, img).getbbox()
93 |
94 |
95 | @pytest.mark.parametrize(
96 | "size, p, channels, tx_mode",
97 | list(
98 | product(
99 | [
100 | (
101 | 224,
102 | 224,
103 | ),
104 | ],
105 | [0.0, 0.5, 1.0],
106 | [
107 | (3,),
108 | ],
109 | ["CASCADE", "CONSUME"],
110 | )
111 | ),
112 | )
113 | def test_tx_on_torch_tensors(
114 | size: t.Tuple[int],
115 |     p: float,
116 | channels: t.Tuple[int],
117 | tx_mode: str,
118 | ) -> None:
119 |
120 | if channels == (None,):
121 | channels = ()
122 |
123 | size = channels + size
124 |
125 | tx = ptx.RandomGrayscale(
126 | p=p,
127 | tx_mode=tx_mode,
128 | )
129 |
130 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
131 |
132 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
133 |
134 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
135 | assert len(params_1) - len(orig_params) == tx.param_count
136 |
137 | orig_params = ()
138 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
139 | aug_3, params_3 = tx.consume_transform(img, params_2)
140 |
141 | assert orig_params == params_3
142 | assert torch.all(torch.eq(aug_2, aug_3))
143 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
144 |
145 | id_params = tx.get_default_params(img=img, processed=True)
146 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
147 | assert id_aug_rem_params == ()
148 | assert torch.all(torch.eq(id_aug_img, img))
149 |
150 |
151 | # Main.
152 | if __name__ == "__main__":
153 | pass
154 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomHorizontalFlip.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.0, 0.5, 1.0],
46 | [(3,), (4,), (None,)],
47 | ["CASCADE", "CONSUME"],
48 | )
49 | ),
50 | )
51 | def test_tx_on_PIL_images(
52 | size: t.Tuple[int],
53 | p: float,
54 | channels: t.Tuple[int],
55 | tx_mode: str,
56 | ) -> None:
57 |
58 |     # No-op context manager helper (unused in this test).
59 | class DummyContext(object):
60 | def __enter__(self):
61 | pass
62 |
63 | def __exit__(self, exc_type, exc_val, exc_tb):
64 | pass
65 |
66 | if channels == (None,):
67 | channels = ()
68 |
69 | size = size + channels
70 |
71 | tx = ptx.RandomHorizontalFlip(
72 | p=p,
73 | tx_mode=tx_mode,
74 | )
75 |
76 | img = Image.fromarray(
77 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
78 | )
79 |
80 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
81 |
82 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
83 | assert len(params_1) - len(orig_params) == tx.param_count
84 |
85 | orig_params = ()
86 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
87 | aug_3, params_3 = tx.consume_transform(img, params_2)
88 |
89 | assert orig_params == params_3
90 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
91 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
92 |
93 | if params_2[-1] == 0:
94 | assert not ImageChops.difference(aug_2, img).getbbox()
95 |
96 | id_params = tx.get_default_params(img=img, processed=True)
97 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
98 | assert id_aug_rem_params == ()
99 | assert not ImageChops.difference(id_aug_img, img).getbbox()
100 |
101 |
102 | @pytest.mark.parametrize(
103 | "size, p, channels, tx_mode",
104 | list(
105 | product(
106 | [
107 | (
108 | 224,
109 | 224,
110 | ),
111 | ],
112 | [0.0, 0.5, 1.0],
113 | [(3,), (4,), (None,)],
114 | ["CASCADE", "CONSUME"],
115 | )
116 | ),
117 | )
118 | def test_tx_on_torch_tensors(
119 | size: t.Tuple[int],
120 | p: float,
121 | channels: t.Tuple[int],
122 | tx_mode: str,
123 | ) -> None:
124 |
125 |     # No-op context manager helper (unused in this test).
126 | class DummyContext(object):
127 | def __enter__(self):
128 | pass
129 |
130 | def __exit__(self, exc_type, exc_val, exc_tb):
131 | pass
132 |
133 | if channels == (None,):
134 | channels = ()
135 |
136 | size = channels + size
137 |
138 | tx = ptx.RandomHorizontalFlip(
139 | p=p,
140 | tx_mode=tx_mode,
141 | )
142 |
143 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
144 |
145 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
146 |
147 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
148 | assert len(params_1) - len(orig_params) == tx.param_count
149 |
150 | orig_params = ()
151 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
152 | aug_3, params_3 = tx.consume_transform(img, params_2)
153 |
154 | assert orig_params == params_3
155 | assert torch.all(torch.eq(aug_2, aug_3))
156 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
157 |
158 | if params_2[-1] == 0:
159 | assert torch.all(torch.eq(aug_2, img))
160 |
161 | id_params = tx.get_default_params(img=img, processed=True)
162 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
163 | assert id_aug_rem_params == ()
164 | assert torch.all(torch.eq(id_aug_img, img))
165 |
166 |
167 | # Main.
168 | if __name__ == "__main__":
169 | pass
170 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomInvert.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.0, 0.5, 1.0],
46 | [
47 | (3,),
48 | # # These dimensions lead to error.
49 | # (4,),
50 | # (None,)
51 | ],
52 | ["CASCADE", "CONSUME"],
53 | )
54 | ),
55 | )
56 | def test_tx_on_PIL_images(
57 | size: t.Tuple[int],
58 | p: float,
59 | channels: t.Tuple[int],
60 | tx_mode: str,
61 | ) -> None:
62 |
63 | if channels == (None,):
64 | channels = ()
65 |
66 | size = size + channels
67 |
68 | tx = ptx.RandomInvert(
69 | p=p,
70 | tx_mode=tx_mode,
71 | )
72 |
73 | img = Image.fromarray(
74 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
75 | )
76 |
77 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
78 |
79 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
80 | assert len(params_1) - len(orig_params) == tx.param_count
81 |
82 | orig_params = ()
83 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
84 | aug_3, params_3 = tx.consume_transform(img, params_2)
85 |
86 | assert orig_params == params_3
87 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
88 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
89 |
90 | id_params = tx.get_default_params(img=img, processed=True)
91 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
92 | assert id_aug_rem_params == ()
93 | assert not ImageChops.difference(id_aug_img, img).getbbox()
94 |
95 |
96 | @pytest.mark.parametrize(
97 | "size, p, channels, tx_mode",
98 | list(
99 | product(
100 | [
101 | (
102 | 224,
103 | 224,
104 | ),
105 | ],
106 | [0.0, 0.5, 1.0],
107 | [
108 | (1,),
109 | (3,),
110 | # (4,),
111 | # (None, ),
112 | ],
113 | ["CASCADE", "CONSUME"],
114 | )
115 | ),
116 | )
117 | def test_tx_on_torch_tensors(
118 | size: t.Tuple[int],
119 | p: float,
120 | channels: t.Tuple[int],
121 | tx_mode: str,
122 | ) -> None:
123 |
124 | if channels == (None,):
125 | channels = ()
126 |
127 | size = channels + size
128 |
129 | tx = ptx.RandomInvert(
130 | p=p,
131 | tx_mode=tx_mode,
132 | )
133 |
134 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
135 |
136 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
137 |
138 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
139 | assert len(params_1) - len(orig_params) == tx.param_count
140 |
141 | orig_params = ()
142 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
143 | aug_3, params_3 = tx.consume_transform(img, params_2)
144 |
145 | assert orig_params == params_3
146 | assert torch.all(torch.eq(aug_2, aug_3))
147 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
148 |
149 | id_params = tx.get_default_params(img=img, processed=True)
150 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
151 | assert id_aug_rem_params == ()
152 | assert torch.all(torch.eq(id_aug_img, img))
153 |
154 |
155 | # Main.
156 | if __name__ == "__main__":
157 | pass
158 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomPerspective.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 | import torchvision.transforms.functional as tv_fn
23 |
24 | import parameterized_transforms.transforms as ptx
25 |
26 | import numpy as np
27 |
28 | import PIL.Image as Image
29 | import PIL.ImageChops as ImageChops
30 |
31 | from itertools import product
32 |
33 | import typing as t
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "repetition, size, distortion_scale, p, interpolation_mode, fill, channels, tx_mode",
38 | list(
39 | product(
40 | [idx for idx in range(2)],
41 | [
42 | (
43 | 224,
44 | 224,
45 | ),
46 | ],
47 | [0.0, 0.5, 1.0],
48 | [0.0, 0.5, 1.0],
49 | [
50 | tv_fn.InterpolationMode.NEAREST,
51 | tv_fn.InterpolationMode.BILINEAR,
52 | tv_fn.InterpolationMode.BICUBIC,
53 | # # The modes below do NOT work.
54 | # tv_fn.InterpolationMode.BOX,
55 | # tv_fn.InterpolationMode.HAMMING,
56 | # tv_fn.InterpolationMode.LANCZOS,
57 | ],
58 | [
59 | 0,
60 | 243,
61 | (23, 54, 211),
62 | (213, 1, 2, 245),
63 | ],
64 | [
65 | # (3,), # This channel does NOT work
66 | # (4,), # This channel does NOT work
67 | (None,), # Random non-reproducible error observed
68 | ],
69 | ["CASCADE", "CONSUME"],
70 | )
71 | ),
72 | )
73 | def test_tx_on_PIL_images(
74 | repetition: int,
75 | size: t.Tuple[int],
76 | distortion_scale: float,
77 | p: float,
78 | interpolation_mode: tv_fn.InterpolationMode,
79 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]],
80 | channels: t.Tuple[int],
81 | tx_mode: str,
82 | ) -> None:
83 |
84 | try:
85 |
86 | if isinstance(fill, float) or isinstance(fill, int):
87 | clean_fill = fill
88 | elif len(fill) == channels[0]:
89 | clean_fill = fill
90 | else:
91 | clean_fill = 0
92 |
93 | if channels == (None,):
94 | channels = ()
95 |
96 | size = size + channels
97 |
98 | tx = ptx.RandomPerspective(
99 | distortion_scale=distortion_scale,
100 | interpolation=interpolation_mode,
101 | fill=clean_fill,
102 | p=p,
103 | tx_mode=tx_mode,
104 | )
105 |
106 | img = Image.fromarray(
107 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
108 | )
109 |
110 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
111 |
112 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
113 | assert len(params_1) - len(orig_params) == tx.param_count
114 |
115 | orig_params = ()
116 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
117 | aug_3, params_3 = tx.consume_transform(img, params_2)
118 |
119 | assert orig_params == params_3
120 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
121 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
122 |
123 | id_params = tx.get_default_params(img=img, processed=True)
124 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
125 | assert id_aug_rem_params == ()
126 | assert not ImageChops.difference(id_aug_img, img).getbbox()
127 |
128 | except torch._C._LinAlgError as e: # noqa
129 | raise RuntimeError(f"ERROR | `torch._C._LinAlgError` raised in the test: {e}")
130 |
131 | except Exception as e:
132 | raise RuntimeError(f"ERROR | Issue raised in the test: {e}")
133 |
134 |
135 | @pytest.mark.parametrize(
136 | "repetition, size, distortion_scale, p, interpolation_mode, fill, channels, tx_mode",
137 | list(
138 | product(
139 | [idx for idx in range(2)],
140 | [
141 | (
142 | 224,
143 | 224,
144 | ),
145 | ],
146 | [0.0, 0.5, 1.0],
147 | [0.0, 0.5, 1.0],
148 | [
149 | tv_fn.InterpolationMode.NEAREST,
150 | tv_fn.InterpolationMode.BILINEAR,
151 | # # The modes below are NOT supported.
152 | # tv_fn.InterpolationMode.BICUBIC,
153 | # tv_fn.InterpolationMode.BOX,
154 | # tv_fn.InterpolationMode.HAMMING,
155 | # tv_fn.InterpolationMode.LANCZOS,
156 | ],
157 | [
158 | 0.0,
159 | 128.0 / 255,
160 | (23.0 / 255, 54.0 / 255, 211.0 / 255),
161 | (213.0 / 255, 1.0 / 255, 2.0 / 255, 245.0 / 255),
162 | ],
163 | [
164 | (3,),
165 | (4,),
166 | # (None,), # Leads to errors in general. Avoiding.
167 | ],
168 | ["CASCADE", "CONSUME"],
169 | )
170 | ),
171 | )
172 | def test_tx_on_torch_tensors(
173 | repetition: int,
174 | size: t.Tuple[int],
175 | distortion_scale: float,
176 | p: float,
177 | interpolation_mode: tv_fn.InterpolationMode,
178 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]],
179 | channels: t.Tuple[int],
180 | tx_mode: str,
181 | ) -> None:
182 | """
183 | Observed issue--
184 | ```
185 | torch._C._LinAlgError: torch.linalg.lstsq:
186 | The least squares solution could not be computed because
187 | the input matrix does not have full rank (error code: 8).
188 | ```
189 |     Currently, the test re-raises this error as a RuntimeError with context if it pops up.
190 | """
191 |
192 | try:
193 |
194 | if isinstance(fill, float) or isinstance(fill, int):
195 | clean_fill = fill
196 | elif len(fill) == channels[0]:
197 | clean_fill = fill
198 | else:
199 | clean_fill = 0
200 |
201 | if channels == (None,):
202 | channels = ()
203 |
204 | size = channels + size
205 |
206 | tx = ptx.RandomPerspective(
207 | distortion_scale=distortion_scale,
208 | interpolation=interpolation_mode,
209 | fill=clean_fill,
210 | p=p,
211 | tx_mode=tx_mode,
212 | )
213 |
214 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
215 |
216 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
217 |
218 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
219 | assert len(params_1) - len(orig_params) == tx.param_count
220 |
221 | orig_params = ()
222 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
223 | aug_3, params_3 = tx.consume_transform(img, params_2)
224 |
225 | assert orig_params == params_3
226 |
227 | if interpolation_mode != tv_fn.InterpolationMode.BILINEAR:
228 |             assert torch.all(torch.isclose(aug_2, aug_3, rtol=1e-02, atol=1e-05))
229 | assert all(
230 | [isinstance(elt, float) or isinstance(elt, int) for elt in params_2]
231 | )
232 |
233 | id_params = tx.get_default_params(img=img, processed=True)
234 | id_aug_img, id_aug_rem_params = tx.consume_transform(
235 | img=img, params=id_params
236 | )
237 | assert id_aug_rem_params == ()
238 | assert torch.all(torch.isclose(id_aug_img, img, rtol=1e-02, atol=1e-05))
239 |
240 | except torch._C._LinAlgError as e: # noqa
241 | raise RuntimeError(f"ERROR | torch._C._LinAlgError raised in the test: {e}")
242 |
243 | except Exception as e:
244 | raise RuntimeError(f"ERROR | Issue raised in the test: {e}")
245 |
246 |
247 | # Main.
248 | if __name__ == "__main__":
249 | pass
250 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomPosterize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, bits, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [num_bits for num_bits in range(9)],
46 | [0.0, 1.0, 0.5],
47 | [
48 | (3,),
49 | # # Not supported.
50 | # (4,),
51 | # (None,),
52 | ],
53 | ["CASCADE", "CONSUME"],
54 | )
55 | ),
56 | )
57 | def test_tx_on_PIL_images(
58 | size: t.Tuple[int],
59 | bits: int,
60 | p: float,
61 | channels: t.Tuple[int],
62 | tx_mode: str,
63 | ) -> None:
64 | if channels == (None,):
65 | channels = ()
66 |
67 | size = size + channels
68 |
69 | tx = ptx.RandomPosterize(
70 | bits=bits,
71 | p=p,
72 | tx_mode=tx_mode,
73 | )
74 |
75 | img = Image.fromarray(
76 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
77 | )
78 |
79 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
80 |
81 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
82 | assert len(params_1) - len(orig_params) == tx.param_count
83 |
84 | orig_params = ()
85 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
86 | aug_3, params_3 = tx.consume_transform(img, params_2)
87 |
88 | assert orig_params == params_3
89 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
90 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
91 |
92 | id_params = tx.get_default_params(img=img, processed=True)
93 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
94 | assert id_aug_rem_params == ()
95 | assert not ImageChops.difference(id_aug_img, img).getbbox()
96 |
97 |
98 | @pytest.mark.parametrize(
99 | "size, bits, p, channels, tx_mode",
100 | list(
101 | product(
102 | [
103 | (
104 | 224,
105 | 224,
106 | ),
107 | ],
108 | [num_bits for num_bits in range(9)],
109 | [0.0, 1.0, 0.5],
110 | [
111 | (3,),
112 | # # Not supported.
113 | # (4,),
114 | # (None,),
115 | ],
116 | ["CASCADE", "CONSUME"],
117 | )
118 | ),
119 | )
120 | def test_tx_on_torch_tensors(
121 | size: t.Tuple[int],
122 | bits: int,
123 | p: float,
124 | channels: t.Tuple[int],
125 | tx_mode: str,
126 | ) -> None:
127 | if channels == (None,):
128 | channels = ()
129 |
130 | size = channels + size
131 |
132 | tx = ptx.RandomPosterize(
133 | bits=bits,
134 | p=p,
135 | tx_mode=tx_mode,
136 | )
137 |
138 | img = torch.from_numpy(
139 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
140 | )
141 |
142 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
143 |
144 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
145 | assert len(params_1) - len(orig_params) == tx.param_count
146 |
147 | orig_params = ()
148 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
149 | aug_3, params_3 = tx.consume_transform(img, params_2)
150 |
151 | assert orig_params == params_3
152 | assert torch.all(torch.eq(aug_2, aug_3))
153 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
154 |
155 | id_params = tx.get_default_params(img=img, processed=True)
156 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
157 | assert id_aug_rem_params == ()
158 | assert torch.all(torch.eq(id_aug_img, img))
159 |
160 |
161 | # Main.
162 | if __name__ == "__main__":
163 | pass
164 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomResizedCrop.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 | import torchvision.transforms.functional as tv_fn
23 |
24 | import parameterized_transforms.transforms as ptx
25 |
26 | import numpy as np
27 |
28 | import PIL.Image as Image
29 | import PIL.ImageChops as ImageChops
30 |
31 | from itertools import product
32 |
33 | import typing as t
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "size, to_size, scale, ratio, interpolation_mode, channels, tx_mode",
38 | list(
39 | product(
40 | [
41 | (
42 | 224,
43 | 224,
44 | ),
45 | ],
46 | [
47 | 224,
48 | 32,
49 | [224, 224],
50 | [32, 32],
51 | (224, 224),
52 | (32, 32),
53 | [224, 32],
54 | (24, 224),
55 | ],
56 | [
57 | (0.33, 0.67),
58 | [0.08, 1.0],
59 | ],
60 | [
61 | (0.08, 1.0),
62 | [0.08, 1.0],
63 | ],
64 | [
65 | tv_fn.InterpolationMode.NEAREST,
66 | tv_fn.InterpolationMode.BILINEAR,
67 | tv_fn.InterpolationMode.BICUBIC,
68 | tv_fn.InterpolationMode.BOX,
69 | tv_fn.InterpolationMode.HAMMING,
70 | tv_fn.InterpolationMode.LANCZOS,
71 | ],
72 | [(3,), (4,), (None,)],
73 | ["CASCADE", "CONSUME"],
74 | )
75 | ),
76 | )
77 | def test_tx_on_PIL_images(
78 | size: t.Tuple[int],
79 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]],
80 | scale: t.Union[t.List[float], t.Tuple[float, float]],
81 | ratio: t.Union[t.List[float], t.Tuple[float, float]],
82 | interpolation_mode: tv_fn.InterpolationMode,
83 | channels: t.Tuple[int],
84 | tx_mode: str,
85 | ) -> None:
86 | if channels == (None,):
87 | channels = ()
88 |
89 | size = size + channels
90 |
91 | tx = ptx.RandomResizedCrop(
92 | size=to_size,
93 | scale=scale,
94 | ratio=ratio,
95 | interpolation=interpolation_mode,
96 | tx_mode=tx_mode,
97 | )
98 |
99 | img = Image.fromarray(
100 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
101 | )
102 |
103 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
104 |
105 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
106 | assert len(params_1) - len(orig_params) == tx.param_count
107 |
108 | orig_params = ()
109 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
110 | aug_3, params_3 = tx.consume_transform(img, params_2)
111 |
112 | assert orig_params == params_3
113 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
114 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
115 |
116 | id_params = tx.get_default_params(img=img, processed=True)
117 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
118 | if img.size == id_aug_img.size:
119 | assert id_aug_rem_params == ()
120 | assert not ImageChops.difference(id_aug_img, img).getbbox()
121 |
122 |
123 | @pytest.mark.parametrize(
124 | "size, to_size, scale, ratio, interpolation_mode, channels, tx_mode",
125 | list(
126 | product(
127 | [
128 | (
129 | 224,
130 | 224,
131 | ),
132 | ],
133 | [
134 | 224,
135 | 32,
136 | [224, 224],
137 | [32, 32],
138 | (224, 224),
139 | (32, 32),
140 | [224, 32],
141 | (24, 224),
142 | ],
143 | [
144 | (0.33, 0.67),
145 | [0.08, 1.0],
146 | ],
147 | [
148 | (0.08, 1.0),
149 | [0.08, 1.0],
150 | ],
151 | [
152 | tv_fn.InterpolationMode.NEAREST,
153 | tv_fn.InterpolationMode.BILINEAR,
154 | tv_fn.InterpolationMode.BICUBIC,
155 |                 # # The modes below are NOT supported for tensors.
156 | # tv_fn.InterpolationMode.BOX,
157 | # tv_fn.InterpolationMode.HAMMING,
158 | # tv_fn.InterpolationMode.LANCZOS,
159 | ],
160 | [(3,), (4,)],
161 | ["CASCADE", "CONSUME"],
162 | )
163 | ),
164 | )
165 | def test_tx_on_torch_tensors(
166 | size: t.Tuple[int],
167 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]],
168 | scale: t.Union[t.List[float], t.Tuple[float, float]],
169 | ratio: t.Union[t.List[float], t.Tuple[float, float]],
170 | interpolation_mode: tv_fn.InterpolationMode,
171 | channels: t.Tuple[int],
172 | tx_mode: str,
173 | ) -> None:
174 | if channels == (None,):
175 | channels = ()
176 |
177 | size = channels + size
178 |
179 | tx = ptx.RandomResizedCrop(
180 | size=to_size,
181 | scale=scale,
182 | ratio=ratio,
183 | interpolation=interpolation_mode,
184 | tx_mode=tx_mode,
185 | )
186 |
187 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
188 |
189 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
190 |
191 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
192 | assert len(params_1) - len(orig_params) == tx.param_count
193 |
194 | orig_params = ()
195 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
196 | aug_3, params_3 = tx.consume_transform(img, params_2)
197 |
198 | assert orig_params == params_3
199 | assert torch.all(torch.eq(aug_2, aug_3))
200 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
201 |
202 | id_params = tx.get_default_params(img=img, processed=True)
203 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
204 | if img.shape == id_aug_img.shape:
205 | assert id_aug_rem_params == ()
206 | assert torch.all(torch.eq(id_aug_img, img))
207 |
208 |
209 | # Main.
210 | if __name__ == "__main__":
211 | pass
212 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomRotation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 | import torchvision.transforms.functional as tv_fn
23 |
24 | import parameterized_transforms.transforms as ptx
25 |
26 | import numpy as np
27 |
28 | import PIL.Image as Image
29 | import PIL.ImageChops as ImageChops
30 |
31 | from itertools import product
32 |
33 | import typing as t
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "size, degrees, interpolation_mode, expand, fill, channels, tx_mode",
38 | list(
39 | product(
40 | [
41 | (
42 | 224,
43 | 224,
44 | ),
45 | ],
46 | [
47 | 10,
48 | 0.0,
49 | [-10, 10.0],
50 | [3.14, 2 * 3.14],
51 | ],
52 | [
53 | tv_fn.InterpolationMode.NEAREST,
54 | tv_fn.InterpolationMode.BILINEAR,
55 | tv_fn.InterpolationMode.BICUBIC,
56 | tv_fn.InterpolationMode.BOX,
57 | tv_fn.InterpolationMode.LANCZOS,
58 | tv_fn.InterpolationMode.HAMMING,
59 | ],
60 | [
61 | True,
62 | False,
63 | ],
64 | [
65 | 0,
66 | 243,
67 | (23, 54, 211),
68 | (213, 1, 2, 245),
69 | ],
70 | [(3,), (4,), (None,)],
71 | ["CASCADE", "CONSUME"],
72 | )
73 | ),
74 | )
75 | def test_tx_on_PIL_images(
76 | size: t.Tuple[int],
77 | degrees: t.Union[
78 | float,
79 | int,
80 | t.List[float],
81 | t.List[int],
82 | t.Tuple[float, float],
83 | t.Tuple[int, int],
84 | ],
85 | interpolation_mode: tv_fn.InterpolationMode,
86 | expand: bool,
87 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]],
88 | channels: t.Tuple[int],
89 | tx_mode: str,
90 | ) -> None:
91 |
92 | if isinstance(fill, float) or isinstance(fill, int):
93 | clean_fill = fill
94 | elif len(fill) == channels[0]:
95 | clean_fill = fill
96 | else:
97 | clean_fill = 0
98 |
99 | if channels == (None,):
100 | channels = ()
101 |
102 | size = size + channels
103 |
104 | tx = ptx.RandomRotation(
105 | degrees=degrees,
106 | interpolation=interpolation_mode,
107 | expand=expand,
108 | center=None,
109 | fill=clean_fill,
110 | tx_mode=tx_mode,
111 | )
112 |
113 | img = Image.fromarray(
114 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
115 | )
116 |
117 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
118 |
119 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
120 | assert len(params_1) - len(orig_params) == tx.param_count
121 |
122 | orig_params = ()
123 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
124 | aug_3, params_3 = tx.consume_transform(img, params_2)
125 |
126 | assert orig_params == params_3
127 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
128 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
129 |
130 | id_params = tx.get_default_params(img=img, processed=True)
131 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
132 | assert id_aug_rem_params == ()
133 | assert not ImageChops.difference(id_aug_img, img).getbbox()
134 |
135 |
136 | @pytest.mark.parametrize(
137 | "size, degrees, interpolation_mode, expand, fill, channels, tx_mode",
138 | list(
139 | product(
140 | [
141 | (
142 | 224,
143 | 224,
144 | ),
145 | ],
146 | [
147 | 10,
148 | 0.0,
149 | [-10, 10.0],
150 | [3.14, 2 * 3.14],
151 | # -1000., # This raises error; if single number, it MUST be positive
152 | ],
153 | [
154 | tv_fn.InterpolationMode.NEAREST,
155 | tv_fn.InterpolationMode.BILINEAR,
156 | tv_fn.InterpolationMode.BICUBIC,
157 | tv_fn.InterpolationMode.BOX,
158 | tv_fn.InterpolationMode.LANCZOS,
159 | tv_fn.InterpolationMode.HAMMING,
160 | ],
161 | [
162 | True,
163 | False,
164 | ],
165 | [
166 | 0.0,
167 | 128.0 / 255,
168 | (23.0 / 255, 54.0 / 255, 211.0 / 255),
169 | (213.0 / 255, 1.0 / 255, 2.0 / 255, 245.0 / 255),
170 | ],
171 | [
172 | (3,),
173 | (4,),
174 | ],
175 | ["CASCADE", "CONSUME"],
176 | )
177 | ),
178 | )
179 | def test_tx_on_torch_tensors(
180 | size: t.Tuple[int],
181 | degrees: t.Union[
182 | float,
183 | int,
184 | t.List[float],
185 | t.List[int],
186 | t.Tuple[float, float],
187 | t.Tuple[int, int],
188 | ],
189 | interpolation_mode: tv_fn.InterpolationMode,
190 | expand: bool,
191 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]],
192 | channels: t.Tuple[int],
193 | tx_mode: str,
194 | ) -> None:
195 |
196 | if isinstance(fill, float) or isinstance(fill, int):
197 | clean_fill = fill
198 | elif len(fill) == channels[0]:
199 | clean_fill = fill
200 | else:
201 | clean_fill = 0
202 |
203 | if channels == (None,):
204 | channels = ()
205 |
206 | size = channels + size
207 |
208 | tx = ptx.RandomRotation(
209 | degrees=degrees,
210 | interpolation=interpolation_mode,
211 | expand=expand,
212 | center=None,
213 | fill=clean_fill,
214 | tx_mode=tx_mode,
215 | )
216 |
217 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
218 |
219 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
220 |
221 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
222 | assert len(params_1) - len(orig_params) == tx.param_count
223 |
224 | orig_params = ()
225 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
226 | aug_3, params_3 = tx.consume_transform(img, params_2)
227 |
228 | assert orig_params == params_3
229 | assert torch.all(torch.eq(aug_2, aug_3))
230 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
231 |
232 | id_params = tx.get_default_params(img=img, processed=True)
233 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
234 | assert id_aug_rem_params == ()
235 | assert torch.all(torch.eq(id_aug_img, img))
236 |
237 |
238 | # Main.
239 | if __name__ == "__main__":
240 | pass
241 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomSolarize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, threshold, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 |             [threshold for threshold in range(9)],
46 | [0, 0.5, 1.0],
47 | [
48 | (3,),
49 | # # Not supported.
50 | # (4,),
51 | # (None,),
52 | ],
53 | ["CASCADE", "CONSUME"],
54 | )
55 | ),
56 | )
57 | def test_tx_on_PIL_images(
58 | size: t.Tuple[int],
59 | threshold: int,
60 | p: float,
61 | channels: t.Tuple[int],
62 | tx_mode: str,
63 | ) -> None:
64 | if channels == (None,):
65 | channels = ()
66 |
67 | size = size + channels
68 |
69 | tx = ptx.RandomSolarize(
70 | threshold=threshold,
71 | p=p,
72 | tx_mode=tx_mode,
73 | )
74 |
75 | img = Image.fromarray(
76 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
77 | )
78 |
79 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
80 |
81 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
82 | assert len(params_1) - len(orig_params) == tx.param_count
83 |
84 | orig_params = ()
85 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
86 | aug_3, params_3 = tx.consume_transform(img, params_2)
87 |
88 | assert orig_params == params_3
89 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
90 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
91 |
92 | id_params = tx.get_default_params(img=img, processed=True)
93 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
94 | assert id_aug_rem_params == ()
95 | assert not ImageChops.difference(id_aug_img, img).getbbox()
96 |
97 |
98 | @pytest.mark.parametrize(
99 | "size, threshold, p, channels, tx_mode",
100 | list(
101 | product(
102 | [
103 | (
104 | 224,
105 | 224,
106 | ),
107 | (
108 | 32,
109 | 32,
110 | ),
111 | (
112 | 28,
113 | 28,
114 | ),
115 | ],
116 |             [threshold for threshold in range(9)],
117 |             # Even out-of-range probabilities are accepted.
118 |             [0, 0.5, 1.0, 255],
119 | [
120 | (3,),
121 | # # Not supported.
122 | # (4,),
123 | # (None,),
124 | ],
125 | ["CASCADE", "CONSUME"],
126 | )
127 | ),
128 | )
129 | def test_tx_on_torch_tensors(
130 | size: t.Tuple[int],
131 | threshold: float,
132 | p: float,
133 | channels: t.Tuple[int],
134 | tx_mode: str,
135 | ) -> None:
136 | if channels == (None,):
137 | channels = ()
138 |
139 | size = channels + size
140 |
141 | tx = ptx.RandomSolarize(
142 | threshold=threshold,
143 | p=p,
144 | tx_mode=tx_mode,
145 | )
146 |
147 | img = torch.from_numpy(
148 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
149 | )
150 |
151 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
152 |
153 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
154 | assert len(params_1) - len(orig_params) == tx.param_count
155 |
156 | orig_params = ()
157 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
158 | aug_3, params_3 = tx.consume_transform(img, params_2)
159 |
160 | assert orig_params == params_3
161 | assert torch.all(torch.eq(aug_2, aug_3))
162 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
163 |
164 | id_params = tx.get_default_params(img=img, processed=True)
165 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
166 | assert id_aug_rem_params == ()
167 | assert torch.all(torch.eq(id_aug_img, img))
168 |
169 |
170 | # Main.
171 | if __name__ == "__main__":
172 | pass
173 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_RandomVerticalFlip.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [0.0, 0.5, 1.0],
46 | [(3,), (4,), (None,)],
47 | ["CASCADE", "CONSUME"],
48 | )
49 | ),
50 | )
51 | def test_tx_on_PIL_images(
52 | size: t.Tuple[int],
53 | p: float,
54 | channels: t.Tuple[int],
55 | tx_mode: str,
56 | ) -> None:
57 |
58 |     # No-op context manager helper (unused in this test).
59 | class DummyContext(object):
60 | def __enter__(self):
61 | pass
62 |
63 | def __exit__(self, exc_type, exc_val, exc_tb):
64 | pass
65 |
66 | if channels == (None,):
67 | channels = ()
68 |
69 | size = size + channels
70 |
71 | tx = ptx.RandomVerticalFlip(
72 | p=p,
73 | tx_mode=tx_mode,
74 | )
75 |
76 | img = Image.fromarray(
77 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
78 | )
79 |
80 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
81 |
82 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
83 | assert len(params_1) - len(orig_params) == tx.param_count
84 |
85 | orig_params = ()
86 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
87 | aug_3, params_3 = tx.consume_transform(img, params_2)
88 |
89 | assert orig_params == params_3
90 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
91 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
92 |
93 | if params_2[-1] == 0:
94 | assert not ImageChops.difference(aug_2, img).getbbox()
95 |
96 | id_params = tx.get_default_params(img=img, processed=True)
97 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
98 | assert id_aug_rem_params == ()
99 | assert not ImageChops.difference(id_aug_img, img).getbbox()
100 |
101 |
102 | @pytest.mark.parametrize(
103 | "size, p, channels, tx_mode",
104 | list(
105 | product(
106 | [
107 | (
108 | 224,
109 | 224,
110 | ),
111 | ],
112 | [0.0, 0.5, 1.0],
113 | [(3,), (4,), (None,)],
114 | ["CASCADE", "CONSUME"],
115 | )
116 | ),
117 | )
118 | def test_tx_on_torch_tensors(
119 | size: t.Tuple[int],
120 | p: float,
121 | channels: t.Tuple[int],
122 | tx_mode: str,
123 | ) -> None:
124 |
125 |     # No-op context manager helper (unused in this test).
126 | class DummyContext(object):
127 | def __enter__(self):
128 | pass
129 |
130 | def __exit__(self, exc_type, exc_val, exc_tb):
131 | pass
132 |
133 | if channels == (None,):
134 | channels = ()
135 |
136 | size = channels + size
137 |
138 | tx = ptx.RandomVerticalFlip(
139 | p=p,
140 | tx_mode=tx_mode,
141 | )
142 |
143 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
144 |
145 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
146 |
147 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
148 | assert len(params_1) - len(orig_params) == tx.param_count
149 |
150 | orig_params = ()
151 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
152 | aug_3, params_3 = tx.consume_transform(img, params_2)
153 |
154 | assert orig_params == params_3
155 | assert torch.all(torch.eq(aug_2, aug_3))
156 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
157 |
158 | if params_2[-1] == 0:
159 | assert torch.all(torch.eq(aug_2, img))
160 |
161 | id_params = tx.get_default_params(img=img, processed=True)
162 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
163 | assert id_aug_rem_params == ()
164 | assert torch.all(torch.eq(id_aug_img, img))
165 |
166 |
167 | # Main.
168 | if __name__ == "__main__":
169 | pass
170 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_Resize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 | import torchvision.transforms.functional as tv_fn
23 |
24 | import parameterized_transforms.transforms as ptx
25 |
26 | import numpy as np
27 |
28 | import PIL.Image as Image
29 | import PIL.ImageChops as ImageChops
30 |
31 | from itertools import product
32 |
33 | import typing as t
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "size, to_size, interpolation_mode, channels, tx_mode",
38 | list(
39 | product(
40 | [
41 | (
42 | 224,
43 | 224,
44 | ),
45 | ],
46 | [
47 | [224, 224],
48 | [32, 32],
49 | (224, 224),
50 | (32, 32),
51 | (224, 32),
52 | [32, 28],
53 | (224,),
54 | (32,),
55 | [
56 | 224,
57 | ],
58 | [
59 | 32,
60 | ],
61 | 224,
62 | 32,
63 | ],
64 | [
65 | tv_fn.InterpolationMode.NEAREST,
66 | tv_fn.InterpolationMode.BILINEAR,
67 | tv_fn.InterpolationMode.BICUBIC,
68 | tv_fn.InterpolationMode.BOX,
69 | tv_fn.InterpolationMode.LANCZOS,
70 | tv_fn.InterpolationMode.HAMMING,
71 | ],
72 | [
73 | (None,),
74 | (3,),
75 | (4,),
76 | ],
77 | ["CASCADE", "CONSUME"],
78 | )
79 | ),
80 | )
81 | def test_tx_on_PIL_images(
82 | size: t.Tuple[int],
83 | to_size: t.Union[t.Tuple[int], t.List[int], int],
84 | interpolation_mode: tv_fn.InterpolationMode,
85 | channels: t.Tuple[int],
86 | tx_mode: str,
87 | ) -> None:
88 |
89 | if channels == (None,):
90 | channels = ()
91 |
92 | size = size + channels
93 |
94 | tx = ptx.Resize(
95 | size=to_size,
96 | interpolation=interpolation_mode,
97 | max_size=None,
98 | antialias=None,
99 | tx_mode=tx_mode,
100 | )
101 |
102 | img = Image.fromarray(
103 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
104 | )
105 |
106 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
107 |
108 | aug_1, params_1 = tx(img, orig_params)
109 |
110 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
111 |
112 | aug_3, params_3 = tx.consume_transform(img, orig_params)
113 |
114 | assert orig_params == params_1
115 | assert orig_params == params_2
116 | assert orig_params == params_3
117 |
118 | assert not ImageChops.difference(aug_1, aug_2).getbbox()
119 | assert not ImageChops.difference(aug_1, aug_3).getbbox()
120 |
121 | id_params = tx.get_default_params(img=img, processed=True)
122 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
123 | if img.size == id_aug_img.size:
124 | assert id_aug_rem_params == ()
125 | assert not ImageChops.difference(id_aug_img, img).getbbox()
126 |
127 |
128 | @pytest.mark.parametrize(
129 | "size, to_size, interpolation_mode, channels, tx_mode",
130 | list(
131 | product(
132 | [
133 | (
134 | 224,
135 | 224,
136 | ),
137 | ],
138 | [
139 | [224, 224],
140 | [32, 32],
141 | (224, 224),
142 | (32, 32),
143 | (224, 32),
144 | [32, 28],
145 | (224,),
146 | (32,),
147 | [
148 | 224,
149 | ],
150 | [
151 | 32,
152 | ],
153 | 224,
154 | 32,
155 | ],
156 | [
157 | tv_fn.InterpolationMode.NEAREST,
158 | tv_fn.InterpolationMode.BILINEAR,
159 | tv_fn.InterpolationMode.BICUBIC,
160 |                 # # The interpolation modes below raise ValueError for tensor inputs.
161 | # tv_fn.InterpolationMode.BOX,
162 | # tv_fn.InterpolationMode.LANCZOS,
163 | # tv_fn.InterpolationMode.HAMMING,
164 | ],
165 | [
166 | (3,),
167 | (4,),
168 | # (None, ) # This raises errors
169 | ],
170 | ["CASCADE", "CONSUME"],
171 | )
172 | ),
173 | )
174 | def test_tx_on_torch_tensors(
175 | size: t.Tuple[int],
176 | to_size: t.Union[t.Tuple[int], t.List[int], int],
177 | interpolation_mode: tv_fn.InterpolationMode,
178 | channels: t.Tuple[int],
179 | tx_mode: str,
180 | ) -> None:
181 | class DummyContext(object):
182 | def __enter__(self):
183 | pass
184 |
185 | def __exit__(self, exc_type, exc_val, exc_tb):
186 | pass
187 |
188 | if channels == (None,):
189 | channels = ()
190 |
191 | size = channels + size
192 |
193 | tx = ptx.Resize(
194 | size=to_size,
195 | interpolation=interpolation_mode,
196 | max_size=None,
197 | antialias=None,
198 | tx_mode=tx_mode,
199 | )
200 |
201 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
202 |
203 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
204 |
205 | if interpolation_mode in [
206 | tv_fn.InterpolationMode.BOX,
207 | tv_fn.InterpolationMode.LANCZOS,
208 | tv_fn.InterpolationMode.HAMMING,
209 | ]:
210 | context = pytest.raises(ValueError)
211 | else:
212 | context = DummyContext()
213 |
214 | with context:
215 |
216 | aug_1, params_1 = tx(img, orig_params)
217 |
218 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
219 |
220 | aug_3, params_3 = tx.consume_transform(img, orig_params)
221 |
222 | assert orig_params == params_1
223 | assert orig_params == params_2
224 | assert orig_params == params_3
225 |
226 | assert torch.all(torch.eq(aug_1, aug_2))
227 | assert torch.all(torch.eq(aug_1, aug_3))
228 |
229 | id_params = tx.get_default_params(img=img, processed=True)
230 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
231 | if img.shape == id_aug_img.shape:
232 | assert id_aug_rem_params == ()
233 | assert torch.all(torch.eq(id_aug_img, img))
234 |
235 |
236 | # Main.
237 | if __name__ == "__main__":
238 | pass
239 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_TenCrop.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, to_size, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | [
46 |                 # 224,  # This errors out, since the crop size can exceed the image size.
47 | 28,
48 | [24, 24],
49 | (24, 24),
50 | [24, 3],
51 | [3, 8],
52 | (8, 24),
53 | ],
54 | [(3,), (4,), (None,)],
55 | ["CASCADE", "CONSUME"],
56 | )
57 | ),
58 | )
59 | def test_tx_on_PIL_images(
60 | size: t.Tuple[int],
61 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]],
62 | channels: t.Tuple[int],
63 | tx_mode: str,
64 | ) -> None:
65 | if channels == (None,):
66 | channels = ()
67 |
68 | size = size + channels
69 |
70 | tx = ptx.TenCrop(
71 | size=to_size,
72 | tx_mode=tx_mode,
73 | )
74 |
75 | img = Image.fromarray(
76 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
77 | )
78 |
79 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
80 |
81 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
82 | assert len(params_1) - len(orig_params) == tx.param_count
83 |
84 | orig_params = ()
85 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
86 | aug_3, params_3 = tx.consume_transform(img, params_2)
87 |
88 | assert orig_params == params_3
89 | assert all(
90 | [
91 | not ImageChops.difference(aug_2_component, aug_3_component).getbbox()
92 | for aug_2_component, aug_3_component in zip(aug_2, aug_3)
93 | ]
94 | )
95 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
96 |
97 | id_params = tx.get_default_params(img=img, processed=True)
98 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
99 | if img.size == id_aug_img[0].size:
100 | assert all(
101 | not ImageChops.difference(img, an_id_aug_img).getbbox()
102 | for an_id_aug_img in id_aug_img[:5]
103 | )
104 | assert id_aug_rem_params == ()
105 |
106 |
107 | @pytest.mark.parametrize(
108 | "size, to_size, channels, tx_mode",
109 | list(
110 | product(
111 | [
112 | (
113 | 224,
114 | 224,
115 | ),
116 | ],
117 | [
118 |                 # 224,  # This errors out, since the crop size can exceed the image size.
119 | 28,
120 | [24, 24],
121 | (24, 24),
122 | [24, 3],
123 | [3, 8],
124 | (8, 24),
125 | ],
126 | [(3,), (4,), (None,)],
127 | ["CASCADE", "CONSUME"],
128 | )
129 | ),
130 | )
131 | def test_tx_on_torch_tensors(
132 | size: t.Tuple[int],
133 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]],
134 | channels: t.Tuple[int],
135 | tx_mode: str,
136 | ) -> None:
137 | if channels == (None,):
138 | channels = ()
139 |
140 | size = channels + size
141 |
142 | tx = ptx.TenCrop(
143 | size=to_size,
144 | tx_mode=tx_mode,
145 | )
146 |
147 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))
148 |
149 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
150 |
151 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
152 | assert len(params_1) - len(orig_params) == tx.param_count
153 |
154 | orig_params = ()
155 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
156 | aug_3, params_3 = tx.consume_transform(img, params_2)
157 |
158 | assert orig_params == params_3
159 | assert all(
160 | [
161 | torch.all(torch.eq(aug_2_component, aug_3_component))
162 | for aug_2_component, aug_3_component in zip(aug_2, aug_3)
163 | ]
164 | )
165 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
166 |
167 | id_params = tx.get_default_params(img=img, processed=True)
168 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
169 | if img.shape == id_aug_img[0].shape:
170 | assert all(
171 | torch.all(torch.eq(img, an_id_aug_img)) for an_id_aug_img in id_aug_img[:5]
172 | )
173 | assert id_aug_rem_params == ()
174 |
175 |
176 | # Main.
177 | if __name__ == "__main__":
178 | pass
179 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_ToPILImage.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Dependencies.
18 | import pytest
19 | import torch
20 |
21 | import parameterized_transforms.transforms as ptx
22 |
23 | import numpy as np
24 |
25 | import PIL.Image as Image
26 | import PIL.ImageChops as ImageChops
27 |
28 | from itertools import product
29 |
30 | import typing as t
31 |
32 |
33 | @pytest.mark.parametrize(
34 | "size, channels, tx_mode",
35 | list(
36 | product(
37 | [
38 | (
39 | 224,
40 | 224,
41 | ),
42 | (
43 | 32,
44 | 32,
45 | ),
46 | (
47 | 28,
48 | 28,
49 | ),
50 | ],
51 | [
52 | (None,),
53 | (1,),
54 | (3,),
55 | (4,),
56 | ],
57 | ["CASCADE", "CONSUME"],
58 | )
59 | ),
60 | )
61 | def test_tx_on_PIL_images(
62 | size: t.Tuple[int],
63 | channels: t.Tuple[int],
64 | tx_mode: str,
65 | ) -> None:
66 |     with pytest.raises(TypeError):  # PIL inputs (and (H, W, 1) arrays) must raise a TypeError here.
67 | if channels == (None,):
68 | channels = ()
69 |
70 | size = size + channels
71 |
72 | tx = ptx.ToPILImage()
73 |
74 | img = Image.fromarray(
75 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
76 | )
77 |
78 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
79 |
80 | aug_1, params_1 = tx(img, orig_params)
81 |
82 |
83 | @pytest.mark.parametrize(
84 | "size, channels, channel_first, high, high_type, tx_mode",
85 | list(
86 | product(
87 | [
88 | (
89 | 224,
90 | 224,
91 | ),
92 | ],
93 | [
94 | (None,),
95 | (1,),
96 | (3,),
97 | (4,),
98 | ],
99 | [False],
100 | [1.0, 256],
101 | [
102 | np.uint8,
103 | # np.int, np.int8, np.int16, np.int32, np.int64 # This does not work
104 | ],
105 | ["CASCADE", "CONSUME"],
106 | )
107 | ),
108 | )
109 | def test_tx_on_ndarrays(
110 | size: t.Tuple[int],
111 | channels: t.Tuple[int],
112 | channel_first: bool,
113 |     high: t.Union[float, int],
114 | high_type: t.Any,
115 | tx_mode: str,
116 | ) -> None:
117 | if channels == (None,):
118 | channels = ()
119 |
120 | if channel_first:
121 | size = channels + size
122 | else:
123 | size = size + channels
124 |
125 | tx = ptx.ToPILImage()
126 |
127 | if high == 1.0:
128 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
129 | else:
130 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8)
131 |
132 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
133 |
134 | aug_1, params_1 = tx(img, orig_params)
135 |
136 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
137 |
138 | aug_3, params_3 = tx.consume_transform(img, orig_params)
139 |
140 | assert orig_params == params_1
141 | assert orig_params == params_2
142 | assert orig_params == params_3
143 |
144 | assert not ImageChops.difference(aug_1, aug_2).getbbox()
145 | assert not ImageChops.difference(aug_1, aug_3).getbbox()
146 |
147 |
148 | @pytest.mark.parametrize(
149 | "size, channels, channel_first, high, high_type, tx_mode",
150 | list(
151 | product(
152 | [
153 | (
154 | 224,
155 | 224,
156 | ),
157 | ],
158 | [
159 | (None,),
160 | (1,),
161 | (3,),
162 | (4,),
163 | ],
164 | [True],
165 | [1.0, 256],
166 | [
167 | np.uint8,
168 | # np.int, np.int8, np.int16, np.int32, np.int64 # This does not work
169 | ],
170 | ["CASCADE", "CONSUME"],
171 | )
172 | ),
173 | )
174 | def test_tx_on_torch_tensors(
175 | size: t.Tuple[int],
176 | channels: t.Tuple[int],
177 | channel_first: bool,
178 |     high: t.Union[float, int],
179 | high_type: t.Any,
180 | tx_mode: str,
181 | ) -> None:
182 | if channels == (None,):
183 | channels = ()
184 |
185 | if channel_first:
186 | size = channels + size
187 | else:
188 | size = size + channels
189 |
190 | tx = ptx.ToPILImage()
191 |
192 | if high == 1.0:
193 | img = torch.from_numpy(
194 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
195 | )
196 | else:
197 | img = torch.from_numpy(
198 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
199 | )
200 |
201 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
202 |
203 | aug_1, params_1 = tx(img, orig_params)
204 |
205 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
206 |
207 | aug_3, params_3 = tx.consume_transform(img, orig_params)
208 |
209 | assert orig_params == params_1
210 | assert orig_params == params_2
211 | assert orig_params == params_3
212 |
213 | assert not ImageChops.difference(aug_1, aug_2).getbbox()
214 | assert not ImageChops.difference(aug_1, aug_3).getbbox()
215 |
216 |
217 | # Main.
218 | if __name__ == "__main__":
219 | pass
220 |
--------------------------------------------------------------------------------
/tests/test_atomic_transforms/test_ToTensor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Dependencies.
18 | import pytest
19 | import torch
20 |
21 | import parameterized_transforms.transforms as ptx
22 |
23 | import numpy as np
24 |
25 | import PIL.Image as Image
26 |
27 | from itertools import product
28 |
29 | import typing as t
30 |
31 |
32 | @pytest.mark.parametrize(
33 | "size, channels, tx_mode",
34 | list(
35 | product(
36 | [
37 | (
38 | 224,
39 | 224,
40 | ),
41 | ],
42 | [
43 | (None,),
44 | (3,),
45 | (4,),
46 | ],
47 | ["CASCADE", "CONSUME"],
48 | )
49 | ),
50 | )
51 | def test_tx_on_PIL_images_1(
52 | size: t.Tuple[int],
53 | channels: t.Tuple[int],
54 | tx_mode: str,
55 | ) -> None:
56 | if channels == (None,):
57 | channels = ()
58 |
59 | size = size + channels
60 |
61 | tx = ptx.ToTensor()
62 |
63 | img = Image.fromarray(
64 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
65 | )
66 |
67 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
68 |
69 | aug_1, params_1 = tx(img, orig_params)
70 |
71 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
72 |
73 | aug_3, params_3 = tx.consume_transform(img, orig_params)
74 |
75 | assert len(list(aug_1.shape)) == 3 # [C, H, W] format ensured
76 | assert aug_1.shape[0] == 1 or aug_1.shape[0] == 3 or aug_1.shape[0] == 4
77 | assert 0 <= torch.min(aug_1).item() <= 1.0
78 | assert 0 <= torch.max(aug_1).item() <= 1.0
79 |
80 | assert orig_params == params_1
81 | assert orig_params == params_2
82 | assert orig_params == params_3
83 |
84 | assert torch.all(torch.eq(aug_1, aug_2))
85 | assert torch.all(torch.eq(aug_1, aug_3))
86 |
87 |
88 | @pytest.mark.parametrize(
89 | "size, channels, tx_mode",
90 | list(
91 | product(
92 | [
93 | (
94 | 224,
95 | 224,
96 | ),
97 | ],
98 | [
99 | (1,),
100 | ],
101 | ["CASCADE", "CONSUME"],
102 | )
103 | ),
104 | )
105 | def test_tx_on_PIL_images_2(
106 | size: t.Tuple[int],
107 | channels: t.Tuple[int],
108 | tx_mode: str,
109 | ) -> None:
110 |     with pytest.raises(TypeError):  # Image.fromarray cannot handle an (H, W, 1) array.
111 | if channels == (None,):
112 | channels = ()
113 |
114 | size = size + channels
115 |
116 | tx = ptx.ToTensor()
117 |
118 | img = Image.fromarray(
119 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
120 | )
121 |
122 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
123 |
124 | aug_1, params_1 = tx(img, orig_params)
125 |
126 |
127 | @pytest.mark.parametrize(
128 | "size, channels, channel_first, high, high_type, tx_mode",
129 | list(
130 | product(
131 | [
132 | (
133 | 224,
134 | 224,
135 | ),
136 | ],
137 | [
138 | (None,),
139 | (1,),
140 | (3,),
141 | (4,),
142 | ],
143 | [True, False],
144 | [1.0, 256],
145 | [
146 | np.float16,
147 | np.float32,
148 | np.float64,
149 | np.uint8,
150 | np.int8,
151 | np.int16,
152 | np.int32,
153 | np.int64,
154 | ],
155 | ["CASCADE", "CONSUME"],
156 | )
157 | ),
158 | )
159 | def test_tx_on_ndarrays(
160 | size: t.Tuple[int],
161 | channels: t.Tuple[int],
162 | channel_first: bool,
163 |     high: t.Union[float, int],
164 | high_type: t.Any,
165 | tx_mode: str,
166 | ) -> None:
167 | if channels == (None,):
168 | channels = ()
169 |
170 | if channel_first:
171 | size = channels + size
172 | else:
173 | size = size + channels
174 |
175 | tx = ptx.ToTensor()
176 |
177 | if high == 1.0:
178 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
179 | else:
180 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8)
181 |
182 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
183 |
184 | aug_1, params_1 = tx(img, orig_params)
185 |
186 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
187 |
188 | aug_3, params_3 = tx.consume_transform(img, orig_params)
189 |
190 | assert len(list(aug_1.shape)) == 3 # [C, H, W] format ensured
191 | assert 0 <= torch.min(aug_1).item() <= high
192 | assert 0 <= torch.max(aug_1).item() <= high
193 |
194 | assert orig_params == params_1
195 | assert orig_params == params_2
196 | assert orig_params == params_3
197 |
198 | assert torch.all(torch.eq(aug_1, aug_2))
199 | assert torch.all(torch.eq(aug_1, aug_3))
200 |
201 |
202 | @pytest.mark.parametrize(
203 | "size, channels, channel_first, high, high_type, tx_mode",
204 | list(
205 | product(
206 | [
207 | (
208 | 224,
209 | 224,
210 | ),
211 | ],
212 | [
213 | (None,),
214 | (1,),
215 | (3,),
216 | (4,),
217 | ],
218 | [True, False],
219 | [1.0, 256],
220 | [
221 | np.float16,
222 | np.float32,
223 | np.float64,
224 | np.uint8,
225 | np.int8,
226 | np.int16,
227 | np.int32,
228 | np.int64,
229 | ],
230 | ["CASCADE", "CONSUME"],
231 | )
232 | ),
233 | )
234 | def test_tx_on_torch_tensors(
235 | size: t.Tuple[int],
236 | channels: t.Tuple[int],
237 | channel_first: bool,
238 |     high: t.Union[float, int],
239 | high_type: t.Any,
240 | tx_mode: str,
241 | ) -> None:
242 |     with pytest.raises(TypeError):  # ToTensor accepts only PIL Images and ndarrays, not torch tensors.
243 |
244 | if channels == (None,):
245 | channels = ()
246 |
247 | if channel_first:
248 | size = channels + size
249 | else:
250 | size = size + channels
251 |
252 | tx = ptx.ToTensor()
253 |
254 | if high == 1.0:
255 | img = torch.from_numpy(
256 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type)
257 | )
258 | if high == 256:
259 | img = torch.from_numpy(
260 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
261 | )
262 |
263 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
264 |
265 | aug_1, params_1 = tx(img, orig_params)
266 |
267 |
268 | # Main.
269 | if __name__ == "__main__":
270 | pass
271 |
--------------------------------------------------------------------------------
/tests/test_composing_transforms/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 |
20 |
21 | # Main.
22 | if __name__ == "__main__":
23 | pass
24 |
--------------------------------------------------------------------------------
/tests/test_composing_transforms/test_RandomChoice.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, p, repeat_count, core_txs, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | # p
46 | [None, "random", "equal"],
47 | # repeat count to try out a large number of orderings
48 | [idx for idx in range(5)],
49 | [
50 | [
51 | ptx.RandomHorizontalFlip(p=0.5),
52 | ptx.ColorJitter(
53 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
54 | ),
55 | ptx.RandomGrayscale(p=0.5),
56 | ],
57 | [
58 | ptx.RandomHorizontalFlip(p=0.5),
59 | ptx.ColorJitter(
60 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
61 | ),
62 | ptx.RandomGrayscale(p=0.5),
63 | ],
64 | ],
65 | [
66 | (3,),
67 | # (4,),
68 | (None,),
69 | ],
70 | ["CASCADE", "CONSUME"],
71 | )
72 | ),
73 | )
74 | def test_tx_on_PIL_images(
75 | size: t.Tuple[int],
76 | p: t.Any,
77 | repeat_count: int,
78 | core_txs: t.List[t.Callable],
79 | channels: t.Tuple[int],
80 | tx_mode: str,
81 | ) -> None:
82 |
83 | # Clean probabilities.
84 | if p == "random":
85 | p = np.random.uniform(
86 | low=0,
87 | high=1.0,
88 | size=[
89 | len(core_txs),
90 | ],
91 | )
92 | p /= np.sum(p)
93 | elif p == "equal":
94 | p = np.ones(
95 | shape=[
96 | len(core_txs),
97 | ]
98 | )
99 | p /= np.sum(p)
100 | elif p is None:
101 | pass
102 | else:
103 | raise NotImplementedError(
104 | "ERROR | Ensure `p` is `None` or `equal` or `random`."
105 | )
106 |
107 | if channels == (None,):
108 | channels = ()
109 |
110 | size = size + channels
111 |
112 | tx = ptx.RandomChoice(transforms=core_txs, p=p, tx_mode=tx_mode)
113 |
114 | img = Image.fromarray(
115 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
116 | )
117 |
118 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
119 |
120 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
121 | assert len(params_1) - len(orig_params) == tx.param_count
122 |
123 | orig_params = ()
124 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
125 | aug_3, params_3 = tx.consume_transform(img, params_2)
126 |
127 | assert orig_params == params_3
128 | if isinstance(aug_2, torch.Tensor):
129 | assert torch.all(torch.eq(aug_2, aug_3))
130 | else:
131 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
132 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
133 |
134 | id_params = tx.get_default_params(img=img, processed=True)
135 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
136 | assert id_aug_rem_params == ()
137 | assert not ImageChops.difference(id_aug_img, img).getbbox()
138 |
139 |
140 | @pytest.mark.parametrize(
141 | "size, p, repeat_count, core_txs, channels, tx_mode",
142 | list(
143 | product(
144 | [
145 | (
146 | 224,
147 | 224,
148 | ),
149 | ],
150 | # p
151 | [None, "random", "equal"],
152 | # repeat count to try out a large number of orderings
153 | [idx for idx in range(5)],
154 | [
155 | [
156 | ptx.RandomHorizontalFlip(p=0.5),
157 | ptx.ColorJitter(
158 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
159 | ),
160 | ptx.RandomGrayscale(p=0.5),
161 | ],
162 | [
163 | ptx.RandomHorizontalFlip(p=0.5),
164 | ptx.ColorJitter(
165 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
166 | ),
167 | ptx.RandomGrayscale(p=0.5),
168 | ],
169 | ],
170 | [
171 | (3,),
172 | ],
173 | ["CASCADE", "CONSUME"],
174 | )
175 | ),
176 | )
177 | def test_tx_on_torch_tensors(
178 | size: t.Tuple[int],
179 |     p: t.Any,
180 |     repeat_count: int,
181 | core_txs: t.List[t.Callable],
182 | channels: t.Tuple[int],
183 | tx_mode: str,
184 | ) -> None:
185 |
186 | # Clean probabilities.
187 | if p == "random":
188 | p = np.random.uniform(
189 | low=0,
190 | high=1.0,
191 | size=[
192 | len(core_txs),
193 | ],
194 | )
195 | p /= np.sum(p)
196 | elif p == "equal":
197 | p = np.ones(
198 | shape=[
199 | len(core_txs),
200 | ]
201 | )
202 | p /= np.sum(p)
203 | elif p is None:
204 | pass
205 | else:
206 | raise NotImplementedError(
207 | "ERROR | Ensure `p` is `None` or `equal` or `random`."
208 | )
209 |
210 | if channels == (None,):
211 | channels = ()
212 |
213 | size = channels + size
214 |
215 | tx = ptx.RandomChoice(transforms=core_txs, p=p, tx_mode=tx_mode)
216 |
217 | img = torch.from_numpy(
218 | np.random.uniform(low=0.0, high=1.0, size=size).astype(np.uint8)
219 | )
220 |
221 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
222 |
223 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
224 | assert len(params_1) - len(orig_params) == tx.param_count
225 |
226 | orig_params = ()
227 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
228 | aug_3, params_3 = tx.consume_transform(img, params_2)
229 |
230 | assert orig_params == params_3
231 | if isinstance(aug_2, torch.Tensor):
232 | assert torch.all(torch.eq(aug_2, aug_3))
233 | else:
234 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
235 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
236 |
237 | id_params = tx.get_default_params(img=img, processed=True)
238 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
239 | assert id_aug_rem_params == ()
240 | assert torch.all(torch.eq(id_aug_img, img))
241 |
242 |
243 | # Main.
244 | if __name__ == "__main__":
245 | pass
246 |
--------------------------------------------------------------------------------
/tests/test_composing_transforms/test_RandomOrder.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | # Dependencies.
19 | import pytest
20 |
21 | import torch
22 |
23 | import parameterized_transforms.transforms as ptx
24 |
25 | import numpy as np
26 |
27 | import PIL.Image as Image
28 | import PIL.ImageChops as ImageChops
29 |
30 | from itertools import product
31 |
32 | import typing as t
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "size, repeat_count, core_txs, channels, tx_mode",
37 | list(
38 | product(
39 | [
40 | (
41 | 224,
42 | 224,
43 | ),
44 | ],
45 | # repeat count to try out a large number of orderings
46 | [idx for idx in range(5)],
47 | [
48 | [
49 | ptx.RandomHorizontalFlip(p=0.5),
50 | ptx.ColorJitter(
51 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
52 | ),
53 | ptx.RandomGrayscale(p=0.5),
54 | ],
55 | [
56 | ptx.RandomHorizontalFlip(p=0.5),
57 | ptx.ColorJitter(
58 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
59 | ),
60 | ptx.RandomGrayscale(p=0.5),
61 | ],
62 | ],
63 | [
64 | (3,),
65 | # (4,),
66 | (None,),
67 | ],
68 | ["CASCADE", "CONSUME"],
69 | )
70 | ),
71 | )
72 | def test_tx_on_PIL_images(
73 | size: t.Tuple[int],
74 | repeat_count: int,
75 | core_txs: t.List[t.Callable],
76 | channels: t.Tuple[int],
77 | tx_mode: str,
78 | ) -> None:
79 |
80 | if channels == (None,):
81 | channels = ()
82 |
83 | size = size + channels
84 |
85 | tx = ptx.RandomOrder(transforms=core_txs, tx_mode=tx_mode)
86 |
87 | img = Image.fromarray(
88 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)
89 | )
90 |
91 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
92 |
93 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
94 | assert len(params_1) - len(orig_params) == tx.param_count
95 |
96 | orig_params = ()
97 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
98 | aug_3, params_3 = tx.consume_transform(img, params_2)
99 |
100 | assert orig_params == params_3
101 | if isinstance(aug_2, torch.Tensor):
102 | assert torch.all(torch.eq(aug_2, aug_3))
103 | else:
104 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
105 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
106 |
107 | id_params = tx.get_default_params(img=img, processed=True)
108 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
109 | assert id_aug_rem_params == ()
110 | assert not ImageChops.difference(id_aug_img, img).getbbox()
111 |
112 |
113 | @pytest.mark.parametrize(
114 | "size, repeat_count, core_txs, channels, tx_mode",
115 | list(
116 | product(
117 | [
118 | (
119 | 224,
120 | 224,
121 | ),
122 | ],
123 | # repeat count to try out a large number of orderings
124 | [idx for idx in range(5)],
125 | [
126 | [
127 | ptx.RandomHorizontalFlip(p=0.5),
128 | ptx.ColorJitter(
129 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
130 | ),
131 | ptx.RandomGrayscale(p=0.5),
132 | ],
133 | [
134 | ptx.RandomHorizontalFlip(p=0.5),
135 | ptx.ColorJitter(
136 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
137 | ),
138 | ptx.RandomGrayscale(p=0.5),
139 | ],
140 | ],
141 | [
142 | (3,),
143 | ],
144 | ["CASCADE", "CONSUME"],
145 | )
146 | ),
147 | )
148 | def test_tx_on_torch_tensors(
149 | size: t.Tuple[int],
150 | repeat_count: int,
151 | core_txs: t.List[t.Callable],
152 | channels: t.Tuple[int],
153 | tx_mode: str,
154 | ) -> None:
155 |
156 | if channels == (None,):
157 | channels = ()
158 |
159 | size = channels + size
160 |
161 | tx = ptx.RandomOrder(transforms=core_txs, tx_mode=tx_mode)
162 |
163 | img = torch.from_numpy(
164 | np.random.uniform(low=0.0, high=1.0, size=size).astype(np.uint8)
165 | )
166 |
167 | orig_params = tuple([3.0, 4.0, -1.0, 0.0])
168 |
169 | aug_1, params_1 = tx.cascade_transform(img, orig_params)
170 | assert len(params_1) - len(orig_params) == tx.param_count
171 |
172 | orig_params = ()
173 | aug_2, params_2 = tx.cascade_transform(img, orig_params)
174 | aug_3, params_3 = tx.consume_transform(img, params_2)
175 |
176 | assert orig_params == params_3
177 | if isinstance(aug_2, torch.Tensor):
178 | assert torch.all(torch.eq(aug_2, aug_3))
179 | else:
180 | assert not ImageChops.difference(aug_2, aug_3).getbbox()
181 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2])
182 |
183 | id_params = tx.get_default_params(img=img, processed=True)
184 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params)
185 | assert id_aug_rem_params == ()
186 | assert torch.all(torch.eq(id_aug_img, img))
187 |
188 |
189 | # Main.
190 | if __name__ == "__main__":
191 | pass
192 |
--------------------------------------------------------------------------------
/tests/test_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Apple Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | import parameterized_transforms.core as ptc
19 | import parameterized_transforms.utils as ptu
20 |
21 | import numpy as np
22 |
23 |
24 | def test_get_total_params_count() -> None:
25 | class Object:
26 | def __init__(self, param_count: int) -> None:
27 | self.param_count = param_count
28 |
29 | def __str__(self) -> str:
30 | return self.__repr__()
31 |
32 | def __repr__(self) -> str:
33 | return f"{self.__class__.__name__}" f"(param_count={self.param_count})"
34 |
35 | get_total_params_count = ptu.get_total_params_count
36 |
37 | for NUM_OBJ in [10, 100]:
38 |
39 | obj_list = [
40 | Object(param_count=np.random.randint(0, 10)) for _ in range(NUM_OBJ)
41 | ]
42 | # print(obj_list)
43 |
44 | obj_tuple = tuple(
45 | Object(param_count=np.random.randint(0, 10)) for _ in range(NUM_OBJ)
46 | )
47 | # print(obj_tuple)
48 |
49 |         obj_dict = dict(  # distinct keys so the dict keeps all NUM_OBJ objects
50 |             [(str(idx), Object(param_count=np.random.randint(0, 10))) for idx in range(NUM_OBJ)]
51 |         )
52 | # print(obj_dict)
53 |
54 | assert get_total_params_count(obj_list) == sum(
55 | [obj.param_count for obj in obj_list]
56 | )
57 | assert get_total_params_count(obj_tuple) == sum(
58 | [obj.param_count for obj in obj_tuple]
59 | )
60 | assert get_total_params_count(obj_dict) == sum(
61 | [obj.param_count for obj_name, obj in obj_dict.items()]
62 | )
63 |
64 |
65 | def test_concat_params() -> None:
66 |
67 | concat_params = ptc.Transform.concat_params
68 |
69 | obj_tuple = [
70 | (),
71 | (
72 | 1,
73 | 2,
74 | 3,
75 | ),
76 | (4,),
77 | (),
78 | (),
79 | (5, 6),
80 | (
81 | 7,
82 | 8,
83 | ),
84 | ]
85 | assert concat_params(*obj_tuple) == (1, 2, 3, 4, 5, 6, 7, 8)
86 |
87 | params_1, params_2, params_3, params_4 = (
88 | (),
89 | (1, 2),
90 | (
91 | 3,
92 | 4,
93 | 5,
94 | ),
95 | (),
96 | )
97 | assert concat_params(params_1, params_2, params_3, params_4) == (1, 2, 3, 4, 5)
98 |
99 |
100 | # Main.
101 | if __name__ == "__main__":
102 | pass
103 |
--------------------------------------------------------------------------------