├── .github └── workflows │ ├── sphinx.yml │ └── testing.yml ├── .gitignore ├── ACKNOWLEDGEMENTS ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile └── source │ ├── conf.py │ ├── index.rst │ ├── python │ ├── core.rst │ ├── transforms.rst │ ├── utils.rst │ └── wrappers.rst │ └── tutorials │ ├── 000-About-the-Package.md │ ├── 001-The-Structure-of-Parametrized-Transforms.md │ ├── 002-How-to-Write-Your-Own-Transforms.md │ ├── 003-A-Brief-Introduction-to-the-Transforms-in-This-Package.md │ ├── 004-Parametrized-Transforms-in-Action.md │ ├── 005-Migrate-To-and-From-torch-in-Three-Easy-Steps.md │ ├── 999-In-a-Nutshell.md │ └── assets │ ├── 004-crazytx-cat-aug-1.jpeg │ ├── 004-crazytx-cat-aug-2.jpeg │ ├── 004-crazytx-cat-orig.jpeg │ ├── 004-pnrce-cat-aug-1.jpeg │ ├── 004-pnrce-cat-aug-2.jpeg │ ├── 004-pnrce-cat-orig.jpeg │ ├── cat-aug-0.jpeg │ ├── cat-aug-1.jpeg │ ├── cat-aug-2.jpeg │ ├── cat-aug-3.jpeg │ ├── cat-aug-4.jpeg │ ├── cat-aug-5.jpeg │ ├── cat.jpeg │ ├── dog-aug-0.jpeg │ ├── dog-aug-1.jpeg │ ├── dog-aug-2.jpeg │ ├── dog-aug-3.jpeg │ ├── dog-aug-4.jpeg │ ├── dog-aug-5.jpeg │ ├── dog-aug-6.jpeg │ ├── dog-aug-7.jpeg │ ├── dog.jpeg │ ├── no-dog-aug-0.jpeg │ ├── no-dog-aug-1.jpeg │ ├── no-dog-aug-2.jpeg │ └── no-dog-aug-3.jpeg ├── examples ├── 002-001-RandomColorErasing.ipynb ├── 002-002-RandomSubsetApply.ipynb ├── 002-003-Visualizations.ipynb ├── 004-001-Playground.ipynb └── 005-001-Torch-Torchvision-to-Parametrized-Transforms.ipynb ├── parameterized_transforms ├── __init__.py ├── core.py ├── transforms.py ├── utils.py └── wrappers.py ├── pyproject.toml ├── requirements.txt └── tests ├── test_atomic_transforms ├── __init__.py ├── test_CenterCrop.py ├── test_ColorJitter.py ├── test_ConvertImageDtype.py ├── test_ElasticTransform.py ├── test_FiveCrop.py ├── test_GaussianBlur.py ├── test_Grayscale.py ├── test_Lambda.py ├── test_LinearTransformation.py ├── test_Normalize.py ├── test_PILToTensor.py ├── test_Pad.py 
├── test_RandomAdjustSharpness.py ├── test_RandomAffine.py ├── test_RandomAutocontrast.py ├── test_RandomCrop.py ├── test_RandomEqualize.py ├── test_RandomErasing.py ├── test_RandomGrayscale.py ├── test_RandomHorizontalFlip.py ├── test_RandomInvert.py ├── test_RandomPerspective.py ├── test_RandomPosterize.py ├── test_RandomResizedCrop.py ├── test_RandomRotation.py ├── test_RandomSolarize.py ├── test_RandomVerticalFlip.py ├── test_Resize.py ├── test_TenCrop.py ├── test_ToPILImage.py └── test_ToTensor.py ├── test_composing_transforms ├── __init__.py ├── test_Compose.py ├── test_RandomApply.py ├── test_RandomChoice.py └── test_RandomOrder.py └── test_functions.py /.github/workflows/sphinx.yml: -------------------------------------------------------------------------------- 1 | name: Deploy sphinx site to Pages 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | pages: write 11 | id-token: write 12 | 13 | concurrency: 14 | group: "pages" 15 | cancel-in-progress: false 16 | 17 | jobs: 18 | build: 19 | runs-on: ubuntu-latest 20 | permissions: 21 | contents: write 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.9' 27 | - name: Build with sphinx 28 | run: | 29 | pip install -e . 
30 | cd docs 31 | pip install sphinx sphinx-book-theme myst-parser 32 | make html 33 | - name: Upload artifact 34 | uses: actions/upload-pages-artifact@v3 35 | with: 36 | path: ./docs/build/html 37 | 38 | deploy: 39 | environment: 40 | name: github-pages 41 | url: ${{ steps.deployment.outputs.page_url }} 42 | runs-on: ubuntu-latest 43 | needs: build 44 | steps: 45 | - name: Deploy to GitHub Pages 46 | id: deployment 47 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Run tests upon pull request events 2 | 3 | on: 4 | pull_request: 5 | branches: ["main"] 6 | 7 | jobs: 8 | pre-merge: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: '3.9' 15 | - name: Build package and test dependencies to run unittests suite using pytest 16 | run: | 17 | pip install -e . 18 | pip install -e '.[test]' 19 | pytest -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE related 132 | .idea 133 | 134 | # Testing related 135 | data 136 | .DS_Store 137 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Parameterized Transforms Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | Note: 6 | The use of code from the below mentioned repository(ies) is to ensure performance parity. 7 | 8 | _____________________ 9 | 10 | Soumith Chintala (pytorch/vision) 11 | 12 | BSD 3-Clause License 13 | 14 | Copyright (c) Soumith Chintala 2016, 15 | All rights reserved. 16 | 17 | Redistribution and use in source and binary forms, with or without 18 | modification, are permitted provided that the following conditions are met: 19 | 20 | * Redistributions of source code must retain the above copyright notice, this 21 | list of conditions and the following disclaimer. 22 | 23 | * Redistributions in binary form must reproduce the above copyright notice, 24 | this list of conditions and the following disclaimer in the documentation 25 | and/or other materials provided with the distribution. 
26 | 27 | * Neither the name of the copyright holder nor the names of its 28 | contributors may be used to endorse or promote products derived from 29 | this software without specific prior written permission. 30 | 31 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 32 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 33 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 34 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 35 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 36 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 37 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 38 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 39 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 40 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Dhekane" 5 | given-names: "Eeshan Gunesh" 6 | orcid: "https://orcid.org/0009-0006-3026-6258" 7 | title: "Parameterized Transforms" 8 | version: 1.0.0 9 | date-released: 2025-02-15 10 | url: "https://github.com/apple/parameterized-transforms" 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the 
standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. In this project, there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parameterized Transforms 2 | 3 | 4 | ## Index 5 | 1. [About the Package](#about-the-package) 6 | 2. [Installation](#installation) 7 | 3. [Getting Started](#getting-started) 8 | 9 | 10 | 11 | ## About the Package 12 | * The package provides a uniform, modular, and easily extendable implementation of `torchvision`-based transforms that provides access to their parameterization. 13 | * With this access, the transforms enable users to achieve the following two important functionalities-- 14 | * Given an image, the transform can return an augmentation along with the parameters used for the augmentation. 
15 | * Given an image and augmentation parameters, the transform can return the corresponding augmentation. 16 | 17 | 18 | 19 | ## Installation 20 | - To install the package directly, run the following commands: 21 | ``` 22 | git clone https://github.com/apple/parameterized-transforms 23 | cd parameterized-transforms 24 | pip install -e . 25 | ``` 26 | - To install the package via `pip`, run the following command: 27 | ``` 28 | pip install --upgrade git+https://github.com/apple/parameterized-transforms 29 | ``` 30 | - If you want to run unit tests locally, run the following steps: 31 | ``` 32 | git clone https://github.com/apple/parameterized-transforms 33 | cd parameterized-transforms 34 | pip install -e . 35 | pip install -e '.[test]' 36 | pytest 37 | ``` 38 | 39 | 40 | 41 | ## Getting Started 42 | * To understand the structure of parameterized transforms and the details of the package, we recommend that the reader 43 | start with 44 | [The First Tutorial](https://apple.github.io/parameterized-transforms/tutorials/000-About-the-Package.html) 45 | of our 46 | [Tutorial Series](https://apple.github.io/parameterized-transforms/). 47 | * However, for a quick start, check out [Parameterized Transforms in a Nutshell](https://apple.github.io/parameterized-transforms/tutorials/999-In-a-Nutshell.html). 48 | 49 | --- 50 | 51 | ## Acknowledgement 52 | In its development, this project received help from multiple researchers, engineers, and other contributors from Apple. 53 | Special thanks to: Tim Kolecke, Jason Ramapuram, Russ Webb, David Koski, Mike Drob, Megan Maher Welsh, Marco Cuturi Cameto, 54 | Dan Busbridge, Xavier Suau Cuadros, and Miguel Sarabia del Castillo. 
55 | 56 | ## Citation 57 | If you find this package useful and want to cite our work, here is the citation: 58 | ``` 59 | @software{Dhekane_Parameterized_Transforms_2025, 60 | author = {Dhekane, Eeshan Gunesh}, 61 | month = {2}, 62 | title = {{Parameterized Transforms}}, 63 | url = {https://github.com/apple/parameterized-transforms}, 64 | version = {1.0.0}, 65 | year = {2025} 66 | } 67 | ``` 68 | 69 | --- 70 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'Parameterized Transforms' 10 | copyright = '2025, Apple Inc' 11 | author = 'Eeshan Gunesh Dhekane' 12 | 13 | # -- General configuration --------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 15 | 16 | extensions = [ 17 | 'myst_parser', 18 | 'sphinx.ext.autodoc', 19 | 'sphinx.ext.autosummary', 20 | 'sphinx.ext.intersphinx', 21 | 'sphinx.ext.githubpages', 22 | 'sphinx.ext.napoleon', 23 | ] 24 | 25 | myst_enable_extensions = ['html_image'] 26 | source_suffix = ['.rst', '.md'] 27 | autoclass_content = 'both' 28 | templates_path = ['_templates'] 29 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 30 | 31 | 32 | 33 | # -- Options for HTML output ------------------------------------------------- 34 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 35 | 36 | html_theme = 'sphinx_book_theme' 37 | html_title = 'Parameterized Transforms' 38 | html_theme_options = { 39 | 'repository_url': 'https://github.com/apple/parameterized-transforms', 40 | 'use_repository_button': True, 41 | 'navigation_with_keys': False, 42 | } 43 | html_static_path = ['_static'] 44 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Parameterized Transforms 2 | ======================== 3 | 4 | 5 | About the Package 6 | ----------------- 7 | 8 | * The package provides a uniform, modular, and easily extendable implementation of `torchvision`-based transforms that provides access to their 
parameterization. 9 | 10 | * With this access, the transforms enable users to achieve the following two important functionalities-- 11 | 12 | * Given an image, return an augmentation along with the parameters used for the augmentation. 13 | 14 | * Given an image and augmentation parameters, return the corresponding augmentation. 15 | 16 | 17 | Installation 18 | ------------ 19 | * To install the package directly, run the following commands: 20 | 21 | .. code-block:: bash 22 | 23 | git clone https://github.com/apple/parameterized-transforms.git 24 | cd parameterized-transforms 25 | pip install -e . 26 | 27 | 28 | * To install the package via ``pip``, run the following command: 29 | 30 | .. code-block:: bash 31 | 32 | pip install --upgrade git+https://github.com/apple/parameterized-transforms.git 33 | 34 | 35 | Get Started 36 | ----------- 37 | * To understand the structure of parameterized transforms and the details of the package, we recommend that the reader start with the :ref:`Tutorial Series <Tutorial-Series-label>`. 38 | 39 | * Otherwise, for a quick start, check out :ref:`Parameterized Transforms in a Nutshell <Quick-Start-label>`. 40 | 41 | 42 | Important Links 43 | --------------- 44 | 45 | .. _Quick-Start-label: 46 | 47 | .. toctree:: 48 | :caption: Quick-Start 49 | :maxdepth: 1 50 | 51 | tutorials/999-In-a-Nutshell 52 | 53 | .. _Tutorial-Series-label: 54 | 55 | .. toctree:: 56 | :caption: Tutorial Series 57 | :maxdepth: 1 58 | 59 | tutorials/000-About-the-Package 60 | tutorials/001-The-Structure-of-Parametrized-Transforms 61 | tutorials/002-How-to-Write-Your-Own-Transforms 62 | tutorials/003-A-Brief-Introduction-to-the-Transforms-in-This-Package 63 | tutorials/004-Parametrized-Transforms-in-Action 64 | tutorials/005-Migrate-To-and-From-torch-in-Three-Easy-Steps 65 | 66 | 67 | ..
toctree:: 68 | :caption: Python API Reference 69 | :maxdepth: 1 70 | 71 | python/core 72 | python/transforms 73 | python/utils 74 | python/wrappers 75 | -------------------------------------------------------------------------------- /docs/source/python/core.rst: -------------------------------------------------------------------------------- 1 | core 2 | ==== 3 | 4 | .. automodule:: parameterized_transforms.core 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/python/transforms.rst: -------------------------------------------------------------------------------- 1 | transforms 2 | ========== 3 | 4 | .. automodule:: parameterized_transforms.transforms 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/python/utils.rst: -------------------------------------------------------------------------------- 1 | utils 2 | ===== 3 | 4 | .. automodule:: parameterized_transforms.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/python/wrappers.rst: -------------------------------------------------------------------------------- 1 | wrappers 2 | ======== 3 | 4 | .. automodule:: parameterized_transforms.wrappers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tutorials/000-About-the-Package.md: -------------------------------------------------------------------------------- 1 | # About the Package 2 | 3 | - 5 Minute Read 4 | 5 | 6 | ## Summary 7 | * This tutorial describes the three important aspects of the `Parameterized Transforms` package: 8 | 1. **Why** do we need this package?, 9 | 2. **What** does the package provide?, and 10 | 3. 
**How** to use the package? 11 | 12 | 13 | * **NOTE:** We will be using the terms **Augmentation (noun) / Augment (verb)** interchangeably with 14 | **Transform (noun) / Transform (verb)** throughout the tutorials. 15 | 16 | 17 | ## The *Why* Aspect 18 | * Augmentation strategies are important in computer vision research for improving the performance of deep learning approaches. 19 | * Popular libraries like `torchvision` and `kornia` provide implementations of widely used and important transforms. 20 | * Many recent research ideas revolve around using information from augmentation parameters in order to learn better representations. 21 | In this context, different popular libraries have different pros and cons: 22 | * For instance, most recent deep learning approaches define their augmentation stacks in terms of 23 | `torchvision`-based transforms, experiment with them, and report the best-performing stacks. 24 | However, `torchvision`-based transforms do NOT provide access to their parameters, thereby limiting research that aims to extract information from augmentation parameters to learn better data representations. 25 | * On the other hand, although `kornia`-based augmentation stacks do provide access to the parameters of the augmentations, 26 | reproducing results obtained with `torchvision` stacks using `kornia`-based augmentations is difficult due to the differences in their implementations. 27 | * Ideally, we want transform implementations that have the following desired properties: 28 | 1. they can provide access to their parameters by exposing them, 29 | 2. they allow reproducible augmentations by enabling application of the transform defined by given parameters, 30 | 3. they are easy to subclass and extend in order to tweak their functionality, and 31 | 4. they have implementations that match those of the transforms used in obtaining the state-of-the-art results (mostly, `torchvision`). 
32 | * This is very difficult to achieve with any of the currently existing libraries. 33 | 34 | 35 | 36 | ## The *What* Aspect 37 | * What this package provides is a modular, uniform, and easily extendable skeleton with a re-implementation of `torchvision`-based 38 | transforms that gives you access to their augmentation parameters and allows reproducible augmentations. 39 | * In particular, these transforms can perform two extremely crucial tasks associated with exposing their parameters: 40 | 1. Given an image, the transform can return an augmentation along with the parameters used for the augmentation. 41 | 2. Given an image and well-defined augmentation parameters, the transform can return the corresponding augmented image. 42 | * The uniform template for all transforms and the modular re-implementation mean that you can easily subclass the 43 | transforms and tweak their functionalities. 44 | * In addition, you can write your own custom transforms using the provided templates and combine them seamlessly 45 | with other custom or package-defined transforms for your experimentation. 46 | 47 | 48 | 49 | ## The *How* Aspect 50 | * To start using the package, we recommend the following-- 51 | 1. Read through the [Prerequisites](#prerequisites) listed below and be well-acquainted with them. 52 | 2. [Install the Package](https://github.com/apple/parameterized-transforms/blob/main/README.md#installation) as described in the link. 53 | 3. Read through the [Tutorial Series](#a-preview-of-all-tutorials). 54 | 4. After that, you should be ready to write and experiment with parameterized transforms! 
55 | 56 | 57 | 58 | ## Prerequisites 59 | Here are the prerequisites for this package-- 60 | * `numpy`: being comfortable with `numpy` arrays and operations, 61 | * `PIL`: being comfortable with basic `PIL` operations and the `PIL.Image.Image` class, 62 | * `torch`: being comfortable with `torch` tensors and operations, and 63 | * `torchvision`: being comfortable with `torchvision` transforms and operations. 64 | 65 | 66 | 67 | ## A Preview of All Tutorials 68 | * Here is an overview of the tutorials in this series and the topics they cover-- 69 | 70 | | Title | Contents | 71 | |-------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 72 | | [0. About the Package](000-About-the-Package.md) | An overview of the package | 73 | | [1. The Structure of Parametrized Transforms](001-The-Structure-of-Parametrized-Transforms.md) | Explanation of the base classes `Transform`, `AtomicTransform`, and `ComposingTransform` | 74 | | [2. How to Write Your Own Transforms](002-How-to-Write-Your-Own-Transforms.md) | A walk-through of writing custom transforms-- an atomic transform named `RandomColorErasing` and a composing transform named `RandomSubsetApply` | 75 | | [3. A Brief Introduction to the Transforms in This Package](003-A-Brief-Introduction-to-the-Transforms-in-This-Package.md) | * Information about all the transforms provided in this package
* Demonstrations of some of the important transforms from the package | 76 | | [4. Parametrized Transforms in Action](004-Parametrized-Transforms-in-Action.md) | * Visualization of augmentations produced by the custom transforms
* Combining custom transforms with the ones from the package
* Extending transforms easily to tweak their behaviors | 77 | | [5. Migrate From `torch`/`torchvision` to `parameterized_transforms` in Three Easy Steps](005-Migrate-To-and-From-torch-in-Three-Easy-Steps.md) | Instructions to easily modify code with `torch`-based datasets/loaders and `torchvision`-based transforms to use the parameterized transforms | 78 | 79 | 80 | 81 | ## Credits 82 | In case you find our work useful in your research, you can use the following `bibtex` entry to cite us-- 83 | ```text 84 | @software{Dhekane_Parameterized_Transforms_2025, 85 | author = {Dhekane, Eeshan Gunesh}, 86 | month = {2}, 87 | title = {{Parameterized Transforms}}, 88 | url = {https://github.com/apple/parameterized-transforms}, 89 | version = {1.0.0}, 90 | year = {2025} 91 | } 92 | ``` 93 | 94 | 95 | 96 | ## About the Next Tutorial 97 | * In the next tutorial [001-The-Structure-of-Parametrized-Transforms.md](001-The-Structure-of-Parametrized-Transforms.md), we will describe the core structure of parameterized transforms. 98 | * We will see two different types of transforms, **Atomic** and **Composing**, and describe their details. 99 | -------------------------------------------------------------------------------- /docs/source/tutorials/999-In-a-Nutshell.md: -------------------------------------------------------------------------------- 1 | # Parameterized Transforms in a Nutshell 2 | 3 | 4 | 1. Parameterized transforms input an image and a tuple of parameters to produce augmentation and modified parameters. 5 | The mode of any parameterized transform can be `CASCADE` or `CONSUME`. 6 | In `CASCADE` mode, the transform generates an augmentation and appends parameters used in this augmentation to the input parameters. 7 | In `CONSUME` mode, the transform removes from the beginning of the tuple the parameters it needs and generates the corresponding augmentation. 8 | The augmentation and the modified parameters are then returned. 
9 | With these two modes, we can have reproducible, flexible augmentations and much more. 10 | ```python 11 | # Signature of parameterized transforms. 12 | import parameterized_transforms.transforms as ptx 13 | import parameterized_transforms.core as ptc 14 | 15 | tx1 = ptx.RandomRotation(degrees=45) # Default mode: CASCADE. 16 | params1 = (3, 2, 0, 1, 0.3, -2.5) # Example: Parameters from previous parameterized transforms. 17 | 18 | augmentation1, modified_params = tx1(image, params1) 19 | # ALTERNATIVELY: augmentation1, modified_params = tx1.cascade_transform(image, params1) 20 | # augmentation: Image rotated by a random angle, say 31.25 degrees. 21 | # modified_params: (3, 2, 0, 1, 0.3, -2.5, 31.25). Note the appended 31.25 angle value. 22 | 23 | tx2 = ptx.RandomRotation(degrees=45, tx_mode=ptc.TransformMode.CONSUME) 24 | params2 = (31.25, 0, 1, 0.5, -1.7) # Example: Parameter for `RandomRotation` (31.25) and possibly other parameterized transforms. 25 | augmentation2, remaining_params = tx2(image, params2) 26 | # ALTERNATIVELY: tx1.consume_transform(image, params2) or tx2.consume_transform(image, params2) 27 | # augmentation2: The same augmentation as augmentation1 above. 28 | # remaining_params: (0, 1, 0.5, -1.7). 29 | ``` 30 | 31 | 2. Parameterized versions of all `torchvision`-based transforms are supported. 32 | ```python 33 | # Example 34 | import parameterized_transforms.transforms as ptx 35 | 36 | tx = ptx.Compose( 37 | [ 38 | ptx.RandomHorizontalFlip(p=0.5), 39 | ptx.RandomApply( 40 | [ 41 | ptx.ColorJitter( 42 | brightness=0.1, 43 | contrast=0.1, 44 | saturation=0.1, 45 | hue=0.1, 46 | ) 47 | ], 48 | p=0.5, 49 | ), 50 | ptx.ToTensor(), 51 | ptx.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 52 | ] 53 | ) 54 | 55 | augmentation, params = tx(image) # Default params: (). 56 | augmentation_1, empty_params = tx.consume_transform(image, params) 57 | # augmentation_1: Same as the augmentation above. 58 | # empty_params: (), as all params are extracted and used. 
59 | ``` 60 | 61 | 3. You can [write your own transforms](https://apple.github.io/parameterized-transforms/tutorials/002-How-to-Write-Your-Own-Transforms.html) that adhere to the structure of parameterized transforms. 62 | Then, your transforms will work seamlessly with those from the package! 63 | ```python 64 | import parameterized_transforms.transforms as ptx 65 | 66 | tx = RandomSubsetApply( # Your custom transform 67 | [ 68 | RandomColorErasing(), # Your custom transform 69 | ptx.RandomHorizontalFlip(p=0.5), 70 | ptx.RandomApply( 71 | [ 72 | ptx.ColorJitter( 73 | brightness=0.1, 74 | contrast=0.1, 75 | saturation=0.1, 76 | hue=0.1, 77 | ) 78 | ], 79 | p=0.5, 80 | ) 81 | ] 82 | ) 83 | 84 | augmentation, params = tx(image) # Default params: (). 85 | augmentation_1, empty_params = tx.consume_transform(image, params) 86 | # augmentation_1: Same as the augmentation above. 87 | # empty_params: (), as all params are extracted and used. 88 | ``` 89 | 90 | 4. You can use parameterized transforms with `torch`/`torchvision` datasets directly. 91 | However, in order to have the parameters output as a single tensor of shape `[batch_size=B, num_params=P]`, we recommend wrapping your transform in the `CastParamsToTensor` wrapper. 92 | More on this in the [tutorial on working with torch/torchvision](https://apple.github.io/parameterized-transforms/tutorials/005-Migrate-To-and-From-torch-in-Three-Easy-Steps.html).
93 | -------------------------------------------------------------------------------- /docs/source/tutorials/assets/004-crazytx-cat-aug-1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-crazytx-cat-aug-1.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/004-crazytx-cat-aug-2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-crazytx-cat-aug-2.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/004-crazytx-cat-orig.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-crazytx-cat-orig.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/004-pnrce-cat-aug-1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-pnrce-cat-aug-1.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/004-pnrce-cat-aug-2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-pnrce-cat-aug-2.jpeg -------------------------------------------------------------------------------- 
/docs/source/tutorials/assets/004-pnrce-cat-orig.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/004-pnrce-cat-orig.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat-aug-0.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-0.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat-aug-1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-1.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat-aug-2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-2.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat-aug-3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-3.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat-aug-4.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-4.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat-aug-5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat-aug-5.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/cat.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-0.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-0.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-1.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-2.jpeg 
-------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-3.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-4.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-5.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-6.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog-aug-7.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog-aug-7.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/dog.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/dog.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/no-dog-aug-0.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-0.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/no-dog-aug-1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-1.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/no-dog-aug-2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-2.jpeg -------------------------------------------------------------------------------- /docs/source/tutorials/assets/no-dog-aug-3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/docs/source/tutorials/assets/no-dog-aug-3.jpeg -------------------------------------------------------------------------------- /parameterized_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | from parameterized_transforms import core 19 | from parameterized_transforms import transforms 20 | from parameterized_transforms import utils 21 | 22 | 23 | ATOMIC_TRANSFORMS = { 24 | "ToTensor": transforms.ToTensor, 25 | "PILToTensor": transforms.PILToTensor, 26 | "ConvertImageDtype": transforms.ConvertImageDtype, 27 | "ToPILImage": transforms.ToPILImage, 28 | "Normalize": transforms.Normalize, 29 | "Resize": transforms.Resize, 30 | "CenterCrop": transforms.CenterCrop, 31 | "Pad": transforms.Pad, 32 | "Lambda": transforms.Lambda, 33 | "RandomCrop": transforms.RandomCrop, 34 | "RandomHorizontalFlip": transforms.RandomHorizontalFlip, 35 | "RandomVerticalFlip": transforms.RandomVerticalFlip, 36 | "RandomPerspective": transforms.RandomPerspective, 37 | "RandomResizedCrop": transforms.RandomResizedCrop, 38 | "FiveCrop": transforms.FiveCrop, 39 | "TenCrop": transforms.TenCrop, 40 | "LinearTransformation": transforms.LinearTransformation, 41 | "ColorJitter": transforms.ColorJitter, 42 | "RandomRotation": transforms.RandomRotation, 43 | "RandomAffine": transforms.RandomAffine, 44 | "Grayscale": transforms.Grayscale, 45 | "RandomGrayscale": transforms.RandomGrayscale, 46 | "RandomErasing": transforms.RandomErasing, 47 | "GaussianBlur": transforms.GaussianBlur, 48 | "RandomInvert": transforms.RandomInvert, 49 | "RandomPosterize": transforms.RandomPosterize, 50 | "RandomSolarize": transforms.RandomSolarize, 51 | "RandomAdjustSharpness": transforms.RandomAdjustSharpness, 52 | "RandomAutocontrast": transforms.RandomAutocontrast, 
53 | "RandomEqualize": transforms.RandomEqualize, 54 | "ElasticTransform": transforms.ElasticTransform, 55 | } 56 | 57 | 58 | COMPOSING_TRANSFORMS = { 59 | "Compose": transforms.Compose, 60 | "RandomApply": transforms.RandomApply, 61 | "RandomOrder": transforms.RandomOrder, 62 | "RandomChoice": transforms.RandomChoice, 63 | } 64 | 65 | 66 | TRANSFORMS = dict( 67 | tuple(ATOMIC_TRANSFORMS.items()) 68 | + tuple(COMPOSING_TRANSFORMS.items()) 69 | ) 70 | -------------------------------------------------------------------------------- /parameterized_transforms/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import enum 19 | import typing as t 20 | 21 | from parameterized_transforms.core import SCALAR_TYPE 22 | 23 | SCALAR_RANGE_TYPE = t.Union[t.List[SCALAR_TYPE], t.Tuple[SCALAR_TYPE, SCALAR_TYPE]] 24 | 25 | 26 | # -------------------------------------------------------------------------------- 27 | # Transforms-related functionality. 28 | # -------------------------------------------------------------------------------- 29 | 30 | 31 | def get_total_params_count(transforms: t.Any) -> int: 32 | """Returns the total number of processed parameters for the given 33 | collection of core transformations. 34 | 35 | :param transforms: The collection of core transforms. 
36 | 37 | :return: The total number of processed parameters. 38 | """ 39 | param_count: int = 0 40 | 41 | if isinstance(transforms, list) or isinstance(transforms, tuple): 42 | 43 | for transform in transforms: 44 | 45 | try: 46 | transform_param_count = transform.param_count 47 | except Exception as e: 48 | raise AttributeError( 49 | "ERROR | `param_count` access " 50 | f"for transform: {transform} failed; " 51 | f"hit error:\n{e}" 52 | ) 53 | 54 | param_count += transform_param_count 55 | 56 | elif isinstance(transforms, dict): 57 | 58 | for _, transform in transforms.items(): 59 | 60 | try: 61 | transform_param_count = transform.param_count 62 | except Exception as e: 63 | raise AttributeError( 64 | f"ERROR | `param_count` access " 65 | f"for transforms: {transform} failed; " 66 | f"hit error:\n{e}" 67 | ) 68 | 69 | param_count += transform_param_count 70 | 71 | else: 72 | 73 | try: 74 | param_count += transforms.param_count 75 | except Exception as e: 76 | raise AttributeError( 77 | "ERROR | `param_count` access" 78 | f"for transforms: {transforms} failed; " 79 | f"hit error:\n{e}" 80 | ) 81 | 82 | return param_count 83 | 84 | 85 | # -------------------------------------------------------------------------------- 86 | # String representation manipulations. 87 | # -------------------------------------------------------------------------------- 88 | 89 | 90 | def indent(data: str, indentor: str = " ", connector: str = "\n") -> str: 91 | """Indents the given string data with given `indentor`. 92 | 93 | :param data: The data to indent. 94 | :param indentor: The indenting character. 95 | DEFAULT: `" "` (two spaces). 96 | :param connector: The string used to concatenate given components. 97 | DEFAULT: `"\n"`. 98 | 99 | :return: The indented data. 100 | 101 | Careful not to have new-line characters in the `indentor`. This is 102 | currently allowed but NOT tested, use at your own risk. 
103 | """ 104 | lines = string_to_components(string=data, separator=connector) 105 | indented_lines = indent_lines(lines=lines, indentor=indentor) 106 | indented_data = components_to_string(components=indented_lines, connector=connector) 107 | 108 | return indented_data 109 | 110 | 111 | def indent_lines(lines: t.List[str], indentor: str = " ") -> t.List[str]: 112 | """Adds as prefix the `indentor` string to each of the given lines. 113 | 114 | :param lines: The lines of the content. 115 | :param indentor: The indenting character. 116 | DEFAULT: `" "` (two spaces). 117 | 118 | :return: The indented lines. 119 | 120 | Careful not to have new-line characters in the `indentor`. This is 121 | currently allowed but NOT tested, use at your own risk. 122 | """ 123 | 124 | return [indent_line(line=line, indentor=indentor) for line in lines] 125 | 126 | 127 | def indent_line(line: str, indentor: str = " ") -> str: 128 | """Indents the given `string` using the given `indentor`. 129 | 130 | :param line: The string to be indented. 131 | :param indentor: The indenting string. 132 | 133 | :returns: The indented string. 134 | """ 135 | 136 | return f"{indentor}{line}" 137 | 138 | 139 | def string_to_components(string: str, separator: str = "\n") -> t.List[str]: 140 | """Split a given string into componenets based on given `separator`. 141 | 142 | :param string: The input string. 143 | :param separator: The string to split the given string into components. 144 | 145 | :returns: The components extracted from this string. 146 | """ 147 | 148 | return string.split(separator) 149 | 150 | 151 | def components_to_string(components: t.List[str], connector: str = ",\n") -> str: 152 | """Concatenate given components into a string using given `connector`. 153 | 154 | :param components: The components of a string. 155 | :param connector: The string used to concatenate given components. 
156 | 157 | :returns: The concatenated string made from the given `components` using 158 | the given `connector` string. 159 | """ 160 | return connector.join(components) 161 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Apple Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Metadata 17 | # ============================================================================ 18 | [project] 19 | 20 | name = "parameterized_transforms" 21 | version = "1.0.0" 22 | description = "Transforms that provide reproducible access to parameterization information." 23 | authors = [ 24 | {name="Eeshan Gunesh Dhekane", email="eeshangunesh_dhekane@apple.com"} 25 | ] 26 | license= {"text"="Apache License, Version 2.0"} 27 | readme = "README.md" 28 | requires-python = ">=3.9.6" # Improve this. 
29 | dependencies = [ 30 | "numpy", 31 | "torch", 32 | "torchvision" 33 | ] 34 | 35 | 36 | [tool.setuptools] 37 | packages = ["parameterized_transforms"] 38 | 39 | 40 | [project.urls] 41 | Repository = "https://github.com/apple/parameterized-transforms" 42 | 43 | 44 | [project.optional-dependencies] 45 | test = [ 46 | "pytest", 47 | "pytest-xdist", 48 | "pytest-cov", 49 | "pytest-memray", 50 | "coverage[toml]" 51 | ] 52 | 53 | 54 | # Pytest options 55 | # ============================================================================ 56 | [tool.pytest.ini_options] 57 | # This determines where tests are found 58 | testpaths = ["tests/"] 59 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==2.0.2 2 | torch==2.6.0 3 | torchvision==0.21.0 -------------------------------------------------------------------------------- /tests/test_atomic_transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/parameterized-transforms/fb7efd844b970365f03e26872da789a63a3f9951/tests/test_atomic_transforms/__init__.py -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_CenterCrop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, to_size, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ( 45 | 32, 46 | 32, 47 | ), 48 | ( 49 | 28, 50 | 28, 51 | ), 52 | ], 53 | [ 54 | [224, 224], 55 | [32, 32], 56 | [28, 28], 57 | (224, 224), 58 | (32, 32), 59 | (28, 28), 60 | (224, 32), 61 | [32, 28], 62 | (224,), 63 | (32,), 64 | (28,), 65 | [ 66 | 224, 67 | ], 68 | [ 69 | 32, 70 | ], 71 | [ 72 | 28, 73 | ], 74 | 224, 75 | 32, 76 | 28, 77 | ], 78 | [ 79 | (None,), 80 | (3,), 81 | (4,), 82 | ], 83 | ["CASCADE", "CONSUME"], 84 | ) 85 | ), 86 | ) 87 | def test_tx_on_PIL_images( 88 | size: t.Tuple[int], 89 | to_size: t.Union[t.Tuple[int], t.List[int], int], 90 | channels: t.Tuple[int], 91 | tx_mode: str, 92 | ) -> None: 93 | 94 | if channels == (None,): 95 | channels = () 96 | 97 | size = size + channels 98 | 99 | tx = ptx.CenterCrop(size=to_size, tx_mode=tx_mode) 100 | 101 | img = Image.fromarray( 102 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 103 | ) 104 | 105 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 106 | 107 | aug_1, params_1 = tx(img, orig_params) 108 | 109 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 110 | 111 | aug_3, params_3 = tx.consume_transform(img, orig_params) 112 | 113 | assert orig_params == params_1 114 | assert orig_params == params_2 115 | assert orig_params == params_3 116 | 117 | assert not ImageChops.difference(aug_1, aug_2).getbbox() 118 | assert not ImageChops.difference(aug_1, aug_3).getbbox() 119 | 120 | id_params = 
tx.get_default_params(img=img, processed=True) 121 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 122 | 123 | if img.size == id_aug_img.size: 124 | assert id_aug_rem_params == () 125 | assert not ImageChops.difference(id_aug_img, img).getbbox() 126 | 127 | 128 | @pytest.mark.parametrize( 129 | "size, to_size, channels, tx_mode", 130 | list( 131 | product( 132 | [ 133 | ( 134 | 224, 135 | 224, 136 | ), 137 | ( 138 | 32, 139 | 32, 140 | ), 141 | ( 142 | 28, 143 | 28, 144 | ), 145 | ], 146 | [ 147 | [224, 224], 148 | [32, 32], 149 | [28, 28], 150 | (224, 224), 151 | (32, 32), 152 | (28, 28), 153 | (224, 32), 154 | [32, 28], 155 | (224,), 156 | (32,), 157 | (28,), 158 | [ 159 | 224, 160 | ], 161 | [ 162 | 32, 163 | ], 164 | [ 165 | 28, 166 | ], 167 | 224, 168 | 32, 169 | 28, 170 | ], 171 | [ 172 | (None,), 173 | (1,), 174 | (3,), 175 | (4,), 176 | ], 177 | ["CASCADE", "CONSUME"], 178 | ) 179 | ), 180 | ) 181 | def test_tx_on_torch_tensors( 182 | size: t.Tuple[int], 183 | to_size: t.Union[t.Tuple[int], t.List[int], int], 184 | channels: t.Tuple[int], 185 | tx_mode: str, 186 | ) -> None: 187 | 188 | if channels == (None,): 189 | channels = () 190 | 191 | size = channels + size 192 | 193 | tx = ptx.CenterCrop(size=to_size, tx_mode=tx_mode) 194 | 195 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 196 | 197 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 198 | 199 | aug_1, params_1 = tx(img, orig_params) 200 | 201 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 202 | 203 | aug_3, params_3 = tx.consume_transform(img, orig_params) 204 | 205 | assert orig_params == params_1 206 | assert orig_params == params_2 207 | assert orig_params == params_3 208 | 209 | assert torch.all(torch.eq(aug_1, aug_2)) 210 | assert torch.all(torch.eq(aug_1, aug_3)) 211 | 212 | id_params = tx.get_default_params(img=img, processed=True) 213 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 214 | 215 | 
if img.shape == id_aug_img.shape: 216 | assert id_aug_rem_params == () 217 | assert torch.all(torch.eq(id_aug_img, img)) 218 | 219 | 220 | # Main. 221 | if __name__ == "__main__": 222 | pass 223 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_ColorJitter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 
19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | import parameterized_transforms.core as ptc 25 | 26 | import numpy as np 27 | 28 | import PIL.Image as Image 29 | import PIL.ImageChops as ImageChops 30 | 31 | from itertools import product 32 | 33 | import typing as t 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "size, brightness, contrast, saturation, hue, channels, tx_mode, default_params_mode", 38 | list( 39 | product( 40 | [ 41 | ( 42 | 224, 43 | 224, 44 | ), 45 | ], 46 | [ 47 | 0.5, 48 | [0.0, 3.0], 49 | ], 50 | [ 51 | 0.5, 52 | [0.0, 3.0], 53 | ], 54 | [ 55 | 0.5, 56 | [0.0, 3.0], 57 | ], 58 | [ 59 | 0.247, 60 | [0.1, 0.3], 61 | ], 62 | [(3,), (4,), (None,)], 63 | ["CASCADE", "CONSUME"], 64 | [ 65 | ptc.DefaultParamsMode.UNIQUE, 66 | ptc.DefaultParamsMode.RANDOMIZED, 67 | ], 68 | ) 69 | ), 70 | ) 71 | def test_tx_on_PIL_images( 72 | size: t.Tuple[int], 73 | brightness: t.Union[float, t.List[float], t.Tuple[float, float]], 74 | contrast: t.Union[float, t.List[float], t.Tuple[float, float]], 75 | saturation: t.Union[float, t.List[float], t.Tuple[float, float]], 76 | hue: t.Union[float, t.List[float], t.Tuple[float, float]], 77 | channels: t.Tuple[int], 78 | tx_mode: str, 79 | default_params_mode: ptc.DefaultParamsMode, 80 | ) -> None: 81 | if channels == (None,): 82 | channels = () 83 | 84 | size = size + channels 85 | 86 | tx = ptx.ColorJitter( 87 | brightness=brightness, 88 | contrast=contrast, 89 | saturation=saturation, 90 | hue=hue, 91 | tx_mode=tx_mode, 92 | default_params_mode=default_params_mode, 93 | ) 94 | 95 | img = Image.fromarray( 96 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 97 | ) 98 | 99 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 100 | 101 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 102 | assert len(params_1) - len(orig_params) == tx.param_count 103 | 104 | orig_params = () 105 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 106 | aug_3, 
params_3 = tx.consume_transform(img, params_2) 107 | 108 | assert orig_params == params_3 109 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 110 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 111 | 112 | id_params = tx.get_default_params(img=img, processed=True) 113 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 114 | assert id_aug_rem_params == () 115 | assert not ImageChops.difference(id_aug_img, img).getbbox() 116 | 117 | 118 | @pytest.mark.parametrize( 119 | "size, brightness, contrast, saturation, hue, channels, tx_mode, default_params_mode", 120 | list( 121 | product( 122 | [ 123 | ( 124 | 224, 125 | 224, 126 | ), 127 | ], 128 | [ 129 | 0.5, 130 | [0.0, 3.0], 131 | ], 132 | [ 133 | 0.5, 134 | [0.0, 3.0], 135 | ], 136 | [ 137 | 0.5, 138 | [0.0, 3.0], 139 | ], 140 | [ 141 | 0.247, 142 | [0.1, 0.3], 143 | ], 144 | [ 145 | (1,), 146 | (3,), 147 | # (4,), # This creates issues 148 | # (None,), # This creates issues 149 | ], 150 | ["CASCADE", "CONSUME"], 151 | [ 152 | ptc.DefaultParamsMode.UNIQUE, 153 | ptc.DefaultParamsMode.RANDOMIZED, 154 | ], 155 | ) 156 | ), 157 | ) 158 | def test_tx_on_torch_tensors( 159 | size: t.Tuple[int], 160 | brightness: t.Union[float, t.List[float], t.Tuple[float, float]], 161 | contrast: t.Union[float, t.List[float], t.Tuple[float, float]], 162 | saturation: t.Union[float, t.List[float], t.Tuple[float, float]], 163 | hue: t.Union[float, t.List[float], t.Tuple[float, float]], 164 | channels: t.Tuple[int], 165 | tx_mode: str, 166 | default_params_mode: ptc.DefaultParamsMode, 167 | ) -> None: 168 | if channels == (None,): 169 | channels = () 170 | 171 | size = channels + size 172 | 173 | tx = ptx.ColorJitter( 174 | brightness=brightness, 175 | contrast=contrast, 176 | saturation=saturation, 177 | hue=hue, 178 | tx_mode=tx_mode, 179 | default_params_mode=default_params_mode, 180 | ) 181 | 182 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, 
size=size)) 183 | 184 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 185 | 186 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 187 | assert len(params_1) - len(orig_params) == tx.param_count 188 | 189 | orig_params = () 190 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 191 | aug_3, params_3 = tx.consume_transform(img, params_2) 192 | 193 | assert orig_params == params_3 194 | assert torch.all(torch.eq(aug_2, aug_3)) 195 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 196 | 197 | id_params = tx.get_default_params(img=img, processed=True) 198 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 199 | assert id_aug_rem_params == () 200 | assert torch.all(torch.eq(id_aug_img, img)) 201 | 202 | 203 | # Main. 204 | if __name__ == "__main__": 205 | pass 206 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_ConvertImageDtype.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 
19 | import pytest 20 | import torch 21 | 22 | import parameterized_transforms.transforms as ptx 23 | 24 | import numpy as np 25 | 26 | import PIL.Image as Image 27 | 28 | from itertools import product 29 | 30 | import typing as t 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "dtype, mode", 35 | list( 36 | product( 37 | [ 38 | torch.int, 39 | torch.int8, 40 | torch.int16, 41 | torch.int32, 42 | torch.int64, 43 | torch.float, 44 | torch.float16, 45 | torch.float32, 46 | torch.float64, 47 | ], 48 | ["CASCADE", "CONSUME"], 49 | ) 50 | ), 51 | ) 52 | def test_tx_on_PIL_images(dtype: t.Any, mode: str) -> None: 53 | 54 | with pytest.raises(TypeError): 55 | 56 | tx = ptx.ConvertImageDtype(dtype=dtype, tx_mode=mode) 57 | 58 | img = Image.fromarray( 59 | np.random.randint(low=0, high=256, size=[224, 224, 3]).astype(np.uint8) 60 | ) 61 | 62 | prev_params = (1, 0.0, -2.0, 3, -4) 63 | 64 | aug_img, params = tx(img, prev_params) 65 | 66 | 67 | @pytest.mark.parametrize( 68 | "size, channels, channel_first, high, high_type, to_dtype, tx_mode", 69 | list( 70 | product( 71 | [ 72 | ( 73 | 224, 74 | 224, 75 | ), 76 | ], 77 | [ 78 | (None,), 79 | (1,), 80 | (3,), 81 | (4,), 82 | ], 83 | [True, False], 84 | [1.0, 256], 85 | [ 86 | np.float16, 87 | np.float32, 88 | np.float64, 89 | np.uint8, 90 | np.int8, 91 | np.int16, 92 | np.int32, 93 | np.int64, 94 | ], 95 | [ 96 | torch.int, 97 | torch.int8, 98 | torch.int16, 99 | torch.int32, 100 | torch.int64, 101 | torch.float, 102 | torch.float16, 103 | torch.float32, 104 | torch.float64, 105 | ], 106 | ["CASCADE", "CONSUME"], 107 | ) 108 | ), 109 | ) 110 | def test_tx_on_ndarrays( 111 | size: t.Tuple[int], 112 | channels: t.Tuple[int], 113 | channel_first: bool, 114 | high: t.Union, 115 | high_type: t.Any, 116 | to_dtype: t.Any, 117 | tx_mode: str, 118 | ) -> None: 119 | if channels == (None,): 120 | channels = () 121 | 122 | if channel_first: 123 | size = channels + size 124 | else: 125 | size = size + channels 126 | 127 | tx = 
ptx.ConvertImageDtype(dtype=to_dtype) 128 | 129 | if high == 1.0: 130 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 131 | else: 132 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8) 133 | 134 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 135 | 136 | with pytest.raises(TypeError): 137 | aug_1, params_1 = tx(img, orig_params) 138 | 139 | 140 | @pytest.mark.parametrize( 141 | "size, channels, channel_first, high, high_type, to_dtype, tx_mode", 142 | list( 143 | product( 144 | [ 145 | ( 146 | 224, 147 | 224, 148 | ), 149 | ], 150 | [ 151 | (None,), 152 | (1,), 153 | (3,), 154 | (4,), 155 | ], 156 | [True, False], 157 | [1.0, 256], 158 | [ 159 | np.float16, 160 | np.float32, 161 | np.float64, 162 | np.uint8, 163 | np.int8, 164 | np.int16, 165 | np.int32, 166 | np.int64, 167 | ], 168 | [ 169 | torch.int, 170 | torch.int8, 171 | torch.int16, 172 | torch.int32, 173 | torch.int64, 174 | torch.float, 175 | torch.float16, 176 | torch.float32, 177 | torch.float64, 178 | ], 179 | ["CASCADE", "CONSUME"], 180 | ) 181 | ), 182 | ) 183 | def test_tx_on_torch_tensors( 184 | size: t.Tuple[int], 185 | channels: t.Tuple[int], 186 | channel_first: bool, 187 | high: t.Union, 188 | high_type: t.Any, 189 | to_dtype: t.Any, 190 | tx_mode: str, 191 | ) -> None: 192 | class DummyContext(object): 193 | def __enter__(self): 194 | pass 195 | 196 | def __exit__(self, exc_type, exc_val, exc_tb): 197 | pass 198 | 199 | if channels == (None,): 200 | channels = () 201 | 202 | if channel_first: 203 | size = channels + size 204 | else: 205 | size = size + channels 206 | 207 | tx = ptx.ConvertImageDtype(dtype=to_dtype) 208 | 209 | if high == 1.0: 210 | img = torch.from_numpy( 211 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 212 | ) 213 | else: 214 | img = torch.from_numpy( 215 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 216 | ) 217 | 218 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 219 | 220 | # Check dtypes. 
221 | img_dtype_name = img.dtype 222 | out_dtype_name = to_dtype 223 | 224 | if (img_dtype_name, to_dtype) in [ 225 | (torch.float32, torch.int32), 226 | (torch.float32, torch.int64), 227 | (torch.float64, torch.int64), 228 | ]: 229 | context = pytest.raises(RuntimeError) 230 | else: 231 | context = DummyContext() 232 | 233 | with context: 234 | 235 | aug_1, params_1 = tx(img, orig_params) 236 | 237 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 238 | 239 | aug_3, params_3 = tx.consume_transform(img, orig_params) 240 | 241 | assert len(list(aug_1.shape)) == 3 or len(list(aug_1.shape)) == 2 242 | 243 | assert orig_params == params_1 244 | assert orig_params == params_2 245 | assert orig_params == params_3 246 | 247 | assert torch.all(torch.eq(aug_1, aug_2)) 248 | assert torch.all(torch.eq(aug_1, aug_3)) 249 | 250 | id_params = tx.get_default_params(img=img, processed=True) 251 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 252 | if id_aug_img.dtype == img.dtype: 253 | assert id_aug_rem_params == () 254 | assert torch.all(torch.eq(id_aug_img, img)) 255 | 256 | 257 | # Main. 258 | if __name__ == "__main__": 259 | pass 260 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_ElasticTransform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | from itertools import product 19 | 20 | import numpy as np 21 | 22 | import parameterized_transforms.transforms as ptx 23 | import parameterized_transforms.core as ptc 24 | 25 | import pytest 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | import torch 31 | import torchvision.transforms.functional as tv_fn 32 | 33 | import typing as t 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "image_shape, alpha, sigma, interpolation, fill, tx_mode, channels", 38 | list( 39 | product( 40 | [ 41 | ( 42 | 224, 43 | 224, 44 | ), 45 | ], # image_shape 46 | [ 47 | 0.0, 48 | 50.0, 49 | [25.0, 100.0], 50 | ], # alpha 51 | [ 52 | 1.0, 53 | 0.5, 54 | 5.0, 55 | ], # sigma 56 | [ 57 | tv_fn.InterpolationMode.NEAREST, 58 | tv_fn.InterpolationMode.BILINEAR, 59 | tv_fn.InterpolationMode.BICUBIC, 60 | # # The modes below give error in applying displacements. 
61 | # # tv_fn.InterpolationMode.BOX, 62 | # # tv_fn.InterpolationMode.LANCZOS, 63 | # # tv_fn.InterpolationMode.HAMMING, 64 | ], # interpolation 65 | [0.0,], # fill 66 | ["CASCADE", "CONSUME"], # tx_mode 67 | [(3,), (4,), (None,)], # channels 68 | ) 69 | ), 70 | ) 71 | def test_tx_on_PIL_images( 72 | image_shape: t.Union[int, t.List[int], t.Tuple[int, int]], 73 | alpha: t.Union[float, t.List[float], t.Tuple[float, float]], 74 | sigma: t.Union[float, t.List[float], t.Tuple[float, float]], 75 | interpolation: tv_fn.InterpolationMode, 76 | fill: t.Union[float, t.Sequence[float]], 77 | tx_mode: ptc.TRANSFORM_MODE_TYPE, 78 | channels: t.Tuple[t.Optional[int]], 79 | ): 80 | if channels == (None,): 81 | channels = () 82 | 83 | size = image_shape + channels 84 | 85 | tx = ptx.ElasticTransform( 86 | image_shape=image_shape, 87 | alpha=alpha, 88 | sigma=sigma, 89 | interpolation=interpolation, 90 | fill=fill, 91 | tx_mode=tx_mode, 92 | ) 93 | 94 | img = Image.fromarray( 95 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 96 | ) 97 | 98 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 99 | 100 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 101 | assert len(params_1) - len(orig_params) == tx.param_count 102 | 103 | orig_params = () 104 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 105 | aug_3, params_3 = tx.consume_transform(img, params_2) 106 | 107 | assert orig_params == params_3 108 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 109 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 110 | 111 | id_params = tx.get_default_params(img=img, processed=True) 112 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 113 | assert id_aug_rem_params == () 114 | assert not ImageChops.difference(id_aug_img, img).getbbox() 115 | 116 | 117 | if __name__ == "__main__": 118 | pass -------------------------------------------------------------------------------- 
/tests/test_atomic_transforms/test_FiveCrop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, to_size, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ( 45 | 32, 46 | 32, 47 | ), 48 | ], 49 | [ 50 | # 224, # This gives error. 
51 | 28, 52 | 24, 53 | 3, 54 | 8, 55 | [24, 24], 56 | [3, 3], 57 | [8, 8], 58 | (24, 24), 59 | (3, 3), 60 | (8, 8), 61 | [24, 3], 62 | [3, 8], 63 | (8, 24), 64 | ], 65 | [(3,), (4,), (None,)], 66 | ["CASCADE", "CONSUME"], 67 | ) 68 | ), 69 | ) 70 | def test_tx_on_PIL_images( 71 | size: t.Tuple[int], 72 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]], 73 | channels: t.Tuple[int], 74 | tx_mode: str, 75 | ) -> None: 76 | if channels == (None,): 77 | channels = () 78 | 79 | size = size + channels 80 | 81 | tx = ptx.FiveCrop( 82 | size=to_size, 83 | tx_mode=tx_mode, 84 | ) 85 | 86 | img = Image.fromarray( 87 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 88 | ) 89 | 90 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 91 | 92 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 93 | assert len(params_1) - len(orig_params) == tx.param_count 94 | 95 | orig_params = () 96 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 97 | aug_3, params_3 = tx.consume_transform(img, params_2) 98 | 99 | assert orig_params == params_3 100 | assert all( 101 | [ 102 | not ImageChops.difference(aug_2_component, aug_3_component).getbbox() 103 | for aug_2_component, aug_3_component in zip(aug_2, aug_3) 104 | ] 105 | ) 106 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 107 | 108 | id_params = tx.get_default_params(img=img, processed=True) 109 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 110 | if img.size == id_aug_img[0].size: 111 | assert all( 112 | not ImageChops.difference(img, an_id_aug_img).getbbox() 113 | for an_id_aug_img in id_aug_img 114 | ) 115 | assert id_aug_rem_params == () 116 | 117 | 118 | @pytest.mark.parametrize( 119 | "size, to_size, channels, tx_mode", 120 | list( 121 | product( 122 | [ 123 | ( 124 | 224, 125 | 224, 126 | ), 127 | ( 128 | 32, 129 | 32, 130 | ), 131 | ], 132 | [ 133 | # 224, # This gives error as crop size > image size is possible 134 | 28, 135 | 24, 
136 | 3, 137 | 8, 138 | [24, 24], 139 | [3, 3], 140 | [8, 8], 141 | (24, 24), 142 | (3, 3), 143 | (8, 8), 144 | [24, 3], 145 | [3, 8], 146 | (8, 24), 147 | ], 148 | [(3,), (4,), (None,)], 149 | ["CASCADE", "CONSUME"], 150 | ) 151 | ), 152 | ) 153 | def test_tx_on_torch_tensors( 154 | size: t.Tuple[int], 155 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]], 156 | channels: t.Tuple[int], 157 | tx_mode: str, 158 | ) -> None: 159 | if channels == (None,): 160 | channels = () 161 | 162 | size = channels + size 163 | 164 | tx = ptx.FiveCrop( 165 | size=to_size, 166 | tx_mode=tx_mode, 167 | ) 168 | 169 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 170 | 171 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 172 | 173 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 174 | assert len(params_1) - len(orig_params) == tx.param_count 175 | 176 | orig_params = () 177 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 178 | aug_3, params_3 = tx.consume_transform(img, params_2) 179 | 180 | assert orig_params == params_3 181 | assert all( 182 | [ 183 | torch.all(torch.eq(aug_2_component, aug_3_component)) 184 | for aug_2_component, aug_3_component in zip(aug_2, aug_3) 185 | ] 186 | ) 187 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 188 | 189 | id_params = tx.get_default_params(img=img, processed=True) 190 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 191 | if img.shape == id_aug_img[0].shape: 192 | assert all( 193 | torch.all(torch.eq(img, an_id_aug_img)) for an_id_aug_img in id_aug_img 194 | ) 195 | assert id_aug_rem_params == () 196 | 197 | 198 | # Main. 199 | if __name__ == "__main__": 200 | pass 201 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_GaussianBlur.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, kernel_size, sigma, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [ 46 | 1, 47 | 3, 48 | # 4, # kernel size should be odd 49 | 5, 50 | 11, 51 | ], 52 | [ 53 | 1, 54 | 0.1, 55 | 0.01, 56 | 10.0, 57 | ], 58 | [(3,), (4,), (None,)], 59 | ["CASCADE", "CONSUME"], 60 | ) 61 | ), 62 | ) 63 | def test_tx_on_PIL_images( 64 | size: t.Tuple[int], 65 | kernel_size: t.Union[int, t.List[int], t.Tuple[int, int]], 66 | sigma: t.Union[float, t.List[float], t.Tuple[float, float]], 67 | channels: t.Tuple[int], 68 | tx_mode: str, 69 | ) -> None: 70 | 71 | if channels == (None,): 72 | channels = () 73 | 74 | size = size + channels 75 | 76 | tx = ptx.GaussianBlur( 77 | kernel_size=kernel_size, 78 | sigma=sigma, 79 | tx_mode=tx_mode, 80 | ) 81 | 82 | img = Image.fromarray( 83 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 84 | ) 85 | 86 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 87 | 88 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 89 | assert len(params_1) - len(orig_params) == 
tx.param_count 90 | 91 | orig_params = () 92 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 93 | aug_3, params_3 = tx.consume_transform(img, params_2) 94 | 95 | assert orig_params == params_3 96 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 97 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 98 | 99 | id_params = tx.get_default_params(img=img, processed=True) 100 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 101 | assert id_aug_rem_params == () 102 | assert not ImageChops.difference(id_aug_img, img).getbbox() 103 | 104 | 105 | @pytest.mark.parametrize( 106 | "size, kernel_size, sigma, channels, tx_mode", 107 | list( 108 | product( 109 | [ 110 | ( 111 | 224, 112 | 224, 113 | ), 114 | ], 115 | [ 116 | 1, 117 | 3, 118 | # 4, # kernel size should be odd 119 | 5, 120 | 11, 121 | ], 122 | [ 123 | 1, 124 | 0.1, 125 | 0.01, 126 | 10.0, 127 | ], 128 | [ 129 | (3,), 130 | (4,), 131 | ], 132 | ["CASCADE", "CONSUME"], 133 | ) 134 | ), 135 | ) 136 | def test_tx_on_torch_tensors( 137 | size: t.Tuple[int], 138 | kernel_size: t.Union[int, t.List[int], t.Tuple[int, int]], 139 | sigma: t.Union[float, t.List[float], t.Tuple[float, float]], 140 | channels: t.Tuple[int], 141 | tx_mode: str, 142 | ) -> None: 143 | 144 | if channels == (None,): 145 | channels = () 146 | 147 | size = channels + size 148 | 149 | tx = ptx.GaussianBlur( 150 | kernel_size=kernel_size, 151 | sigma=sigma, 152 | tx_mode=tx_mode, 153 | ) 154 | 155 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 156 | 157 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 158 | 159 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 160 | assert len(params_1) - len(orig_params) == tx.param_count 161 | 162 | orig_params = () 163 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 164 | aug_3, params_3 = tx.consume_transform(img, params_2) 165 | 166 | assert orig_params == params_3 167 | assert 
torch.all(torch.eq(aug_2, aug_3)) 168 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 169 | 170 | id_params = tx.get_default_params(img=img, processed=True) 171 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 172 | assert id_aug_rem_params == () 173 | assert torch.all(torch.eq(id_aug_img, img)) 174 | 175 | 176 | # Main. 177 | if __name__ == "__main__": 178 | pass 179 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_Grayscale.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 
import pytest

import torch

import parameterized_transforms.transforms as ptx

import numpy as np

import PIL.Image as Image
import PIL.ImageChops as ImageChops

from itertools import product

import typing as t


@pytest.mark.parametrize(
    "size, to_channels, channels, tx_mode",
    list(
        product(
            [(224, 224)],
            [1, 3],
            [(3,), (4,), (None,)],
            ["CASCADE", "CONSUME"],
        )
    ),
)
def test_tx_on_PIL_images(
    size: t.Tuple[int, int],
    to_channels: int,
    channels: t.Tuple[t.Optional[int]],
    tx_mode: str,
) -> None:
    if channels == (None,):
        channels = ()

    size = size + channels

    tx = ptx.Grayscale(num_output_channels=to_channels, tx_mode=tx_mode)

    img = Image.fromarray(
        np.random.randint(low=0, high=256, size=size).astype(np.uint8)
    )

    orig_params = (3.0, 4.0, -1.0, 0.0)

    # Cascading appends exactly `param_count` new params to the incoming ones.
    aug_1, params_1 = tx.cascade_transform(img, orig_params)
    assert len(params_1) - len(orig_params) == tx.param_count

    # Consuming the params produced by a cascade reproduces the same image.
    orig_params = ()
    aug_2, params_2 = tx.cascade_transform(img, orig_params)
    aug_3, params_3 = tx.consume_transform(img, params_2)

    assert orig_params == params_3
    assert not ImageChops.difference(aug_2, aug_3).getbbox()
    assert all(isinstance(elt, (float, int)) for elt in params_2)

    # No identity check: this transform always alters the image, so there are
    # no default params that leave it unchanged.


@pytest.mark.parametrize(
    "size, to_channels, channels, tx_mode",
    list(
        product(
            [(224, 224)],
            [1, 3],
            [(3,)],
            ["CASCADE", "CONSUME"],
        )
    ),
)
def test_tx_on_torch_tensors(
    size: t.Tuple[int, int],
    to_channels: int,
    channels: t.Tuple[t.Optional[int]],
    tx_mode: str,
) -> None:
    if channels == (None,):
        channels = ()

    size = channels + size

    tx = ptx.Grayscale(num_output_channels=to_channels, tx_mode=tx_mode)

    img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size))

    orig_params = (3.0, 4.0, -1.0, 0.0)

    # Cascading appends exactly `param_count` new params to the incoming ones.
    aug_1, params_1 = tx.cascade_transform(img, orig_params)
    assert len(params_1) - len(orig_params) == tx.param_count

    # Consuming the params produced by a cascade reproduces the same image.
    orig_params = ()
    aug_2, params_2 = tx.cascade_transform(img, orig_params)
    aug_3, params_3 = tx.consume_transform(img, params_2)

    assert orig_params == params_3
    assert torch.all(torch.eq(aug_2, aug_3))
    assert all(isinstance(elt, (float, int)) for elt in params_2)

    # No identity check: this transform always alters the image, so there are
    # no default params that leave it unchanged.


# Main.
154 | if __name__ == "__main__": 155 | pass 156 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_Lambda.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, lambd, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [ 46 | lambda x: x, 47 | ], 48 | [(3,), (4,), (None,)], 49 | ["CASCADE", "CONSUME"], 50 | ) 51 | ), 52 | ) 53 | def test_tx_on_PIL_images( 54 | size: t.Tuple[int], 55 | lambd: t.Callable, 56 | channels: int, 57 | tx_mode: str, 58 | ) -> None: 59 | 60 | if channels == (None,): 61 | channels = () 62 | 63 | size = size + channels 64 | 65 | tx = ptx.Lambda(lambd=lambd, tx_mode=tx_mode) 66 | 67 | img = Image.fromarray( 68 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 69 | ) 70 | 71 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 72 | 73 | aug_1, params_1 = tx(img, orig_params) 74 | 75 | aug_2, params_2 = tx.cascade_transform(img, 
orig_params) 76 | 77 | aug_3, params_3 = tx.consume_transform(img, orig_params) 78 | 79 | assert orig_params == params_1 80 | assert orig_params == params_2 81 | assert orig_params == params_3 82 | 83 | assert not ImageChops.difference(aug_1, aug_2).getbbox() 84 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 85 | 86 | # # This transform does NOT have identity params. 87 | # # With given identity function, we get the same image back. 88 | id_params = tx.get_default_params(img=img, processed=True) 89 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 90 | assert id_aug_rem_params == () 91 | assert not ImageChops.difference(id_aug_img, img).getbbox() 92 | 93 | 94 | @pytest.mark.parametrize( 95 | "size, lambd, channels, tx_mode", 96 | list( 97 | product( 98 | [ 99 | ( 100 | 224, 101 | 224, 102 | ), 103 | ], 104 | [ 105 | lambda x: x, 106 | ], 107 | [(3,), (4,), (None,)], 108 | ["CASCADE", "CONSUME"], 109 | ) 110 | ), 111 | ) 112 | def test_tx_on_torch_tensors( 113 | size: t.Tuple[int], 114 | lambd: t.Callable, 115 | channels: int, 116 | tx_mode: str, 117 | ) -> None: 118 | 119 | if channels == (None,): 120 | channels = () 121 | 122 | size = channels + size 123 | 124 | tx = ptx.Lambda(lambd=lambd, tx_mode=tx_mode) 125 | 126 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 127 | 128 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 129 | 130 | aug_1, params_1 = tx(img, orig_params) 131 | 132 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 133 | 134 | aug_3, params_3 = tx.consume_transform(img, orig_params) 135 | 136 | assert orig_params == params_1 137 | assert orig_params == params_2 138 | assert orig_params == params_3 139 | 140 | assert torch.all(torch.eq(aug_1, aug_2)) 141 | assert torch.all(torch.eq(aug_2, aug_3)) 142 | 143 | # # This transform does NOT have identity params. 144 | # # With given identity function, we get the same image back. 
145 | id_params = tx.get_default_params(img=img, processed=True) 146 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 147 | assert id_aug_rem_params == () 148 | assert torch.all(torch.eq(id_aug_img, img)) 149 | 150 | 151 | # Main. 152 | if __name__ == "__main__": 153 | pass 154 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_LinearTransformation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | from itertools import product 28 | 29 | import typing as t 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "size, channels, tx_mode, repeats", 34 | list( 35 | product( 36 | [ 37 | ( 38 | 11, 39 | 11, 40 | ), 41 | ( 42 | 12, 43 | 12, 44 | ), 45 | ( 46 | 13, 47 | 13, 48 | ), 49 | ], 50 | [ 51 | (3,), 52 | (4,), 53 | # (None,) # This gives error. 
54 | ], 55 | ["CASCADE", "CONSUME"], 56 | [ 57 | 100, 58 | ], 59 | ) 60 | ), 61 | ) 62 | def test_tx_on_torch_tensors( 63 | size: t.Tuple[int], 64 | channels: t.Tuple[int], 65 | tx_mode: str, 66 | repeats: int, 67 | ) -> None: 68 | if channels == (None,): 69 | channels = () 70 | 71 | size = channels + size 72 | 73 | dim = np.prod(size) 74 | 75 | for _ in range(repeats): 76 | 77 | tx_matrix = torch.Tensor(dim, dim).uniform_().double() 78 | tx_mean = ( 79 | torch.Tensor( 80 | dim, 81 | ) 82 | .uniform_() 83 | .double() 84 | ) 85 | 86 | tx = ptx.LinearTransformation( 87 | transformation_matrix=tx_matrix, 88 | mean_vector=tx_mean, 89 | tx_mode=tx_mode, 90 | ) 91 | 92 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 93 | 94 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 95 | 96 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 97 | assert len(params_1) - len(orig_params) == tx.param_count 98 | 99 | orig_params = () 100 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 101 | aug_3, params_3 = tx.consume_transform(img, params_2) 102 | 103 | assert orig_params == params_3 104 | assert torch.all(torch.eq(aug_2, aug_3)) 105 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 106 | 107 | # # The identity params SHOULD be `sigma=0` but this does NOT work correctly in code. 108 | # id_params = tx.get_default_params(img=img, processed=True) 109 | # id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 110 | # assert id_aug_rem_params == () 111 | # assert torch.all(torch.eq(id_aug_img, img)) 112 | 113 | 114 | # Main. 115 | if __name__ == "__main__": 116 | pass 117 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_PILToTensor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | # Dependencies. 18 | import pytest 19 | import torch 20 | 21 | import parameterized_transforms.transforms as ptx 22 | 23 | import numpy as np 24 | 25 | import PIL.Image as Image 26 | 27 | from itertools import product 28 | 29 | import typing as t 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "size, channels, tx_mode", 34 | list( 35 | product( 36 | [ 37 | ( 38 | 224, 39 | 224, 40 | ), 41 | ], 42 | [ 43 | (None,), 44 | (3,), 45 | (4,), 46 | ], 47 | ["CASCADE", "CONSUME"], 48 | ) 49 | ), 50 | ) 51 | def test_tx_on_PIL_images_1( 52 | size: t.Tuple[int], 53 | channels: t.Tuple[int], 54 | tx_mode: str, 55 | ) -> None: 56 | if channels == (None,): 57 | channels = () 58 | 59 | size = size + channels 60 | 61 | tx = ptx.PILToTensor() 62 | 63 | img = Image.fromarray( 64 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 65 | ) 66 | 67 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 68 | 69 | aug_1, params_1 = tx(img, orig_params) 70 | 71 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 72 | 73 | aug_3, params_3 = tx.consume_transform(img, orig_params) 74 | 75 | assert len(list(aug_1.shape)) == 3 # [C, H, W] format ensured 76 | assert aug_1.shape[0] == 1 or aug_1.shape[0] == 3 or aug_1.shape[0] == 4 77 | assert 0 <= torch.min(aug_1).item() <= 255.0 78 | assert 0 <= torch.max(aug_1).item() <= 255.0 79 | 80 | assert orig_params == params_1 81 | assert orig_params == params_2 82 | 
assert orig_params == params_3 83 | 84 | assert torch.all(torch.eq(aug_1, aug_2)) 85 | assert torch.all(torch.eq(aug_1, aug_3)) 86 | 87 | # # This transform does NOT have identity params. 88 | # id_params = tx.get_default_params(img=img, processed=True) 89 | # id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 90 | # assert id_aug_rem_params == () 91 | # assert not ImageChops.difference(id_aug_img, img).getbbox() 92 | 93 | 94 | @pytest.mark.parametrize( 95 | "size, channels, tx_mode", 96 | list( 97 | product( 98 | [ 99 | ( 100 | 224, 101 | 224, 102 | ), 103 | ], 104 | [ 105 | (1,), 106 | ], 107 | ["CASCADE", "CONSUME"], 108 | ) 109 | ), 110 | ) 111 | def test_tx_on_PIL_images_2( 112 | size: t.Tuple[int], 113 | channels: t.Tuple[int], 114 | tx_mode: str, 115 | ) -> None: 116 | with pytest.raises(TypeError): 117 | if channels == (None,): 118 | channels = () 119 | 120 | size = size + channels 121 | 122 | tx = ptx.PILToTensor() 123 | 124 | img = Image.fromarray( 125 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 126 | ) 127 | 128 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 129 | 130 | aug_1, params_1 = tx(img, orig_params) 131 | 132 | 133 | @pytest.mark.parametrize( 134 | "size, channels, channel_first, high, high_type, tx_mode", 135 | list( 136 | product( 137 | [ 138 | ( 139 | 224, 140 | 224, 141 | ), 142 | ], 143 | [ 144 | (None,), 145 | (1,), 146 | (3,), 147 | (4,), 148 | ], 149 | [True, False], 150 | [1.0, 256], 151 | [ 152 | np.float16, 153 | np.float32, 154 | np.float64, 155 | np.uint8, 156 | np.int8, 157 | np.int16, 158 | np.int32, 159 | np.int64, 160 | ], 161 | ["CASCADE", "CONSUME"], 162 | ) 163 | ), 164 | ) 165 | def test_tx_on_ndarrays( 166 | size: t.Tuple[int], 167 | channels: t.Tuple[int], 168 | channel_first: bool, 169 | high: t.Union, 170 | high_type: t.Any, 171 | tx_mode: str, 172 | ) -> None: 173 | if channels == (None,): 174 | channels = () 175 | 176 | if channel_first: 177 | size = channels + size 178 | 
else: 179 | size = size + channels 180 | 181 | tx = ptx.PILToTensor() 182 | 183 | if high == 1.0: 184 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 185 | else: 186 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8) 187 | 188 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 189 | 190 | with pytest.raises(TypeError): 191 | 192 | aug_1, params_1 = tx(img, orig_params) 193 | 194 | 195 | @pytest.mark.parametrize( 196 | "size, channels, channel_first, high, high_type, tx_mode", 197 | list( 198 | product( 199 | [ 200 | ( 201 | 224, 202 | 224, 203 | ), 204 | ], 205 | [ 206 | (None,), 207 | (1,), 208 | (3,), 209 | (4,), 210 | ], 211 | [True, False], 212 | [1.0, 256], 213 | [ 214 | np.float16, 215 | np.float32, 216 | np.float64, 217 | np.uint8, 218 | np.int8, 219 | np.int16, 220 | np.int32, 221 | np.int64, 222 | ], 223 | ["CASCADE", "CONSUME"], 224 | ) 225 | ), 226 | ) 227 | def test_tx_on_torch_tensors( 228 | size: t.Tuple[int], 229 | channels: t.Tuple[int], 230 | channel_first: bool, 231 | high: t.Union, 232 | high_type: t.Any, 233 | tx_mode: str, 234 | ) -> None: 235 | if channels == (None,): 236 | channels = () 237 | 238 | if channel_first: 239 | size = channels + size 240 | else: 241 | size = size + channels 242 | 243 | tx = ptx.PILToTensor() 244 | 245 | if high == 1.0: 246 | img = torch.from_numpy( 247 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 248 | ) 249 | else: 250 | img = torch.from_numpy( 251 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 252 | ) 253 | 254 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 255 | 256 | with pytest.raises(TypeError): 257 | 258 | aug_1, params_1 = tx(img, orig_params) 259 | 260 | 261 | # Main. 
262 | if __name__ == "__main__": 263 | pass 264 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomAdjustSharpness.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, sharpness_factor, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.1, 0.5, 1.0, 2.0], 46 | [0.0, 0.5, 1.0], 47 | [ 48 | (3,), 49 | (4,), 50 | (None,), 51 | ], 52 | ["CASCADE", "CONSUME"], 53 | ) 54 | ), 55 | ) 56 | def test_tx_on_PIL_images( 57 | size: t.Tuple[int], 58 | sharpness_factor: float, 59 | p: float, 60 | channels: t.Tuple[int], 61 | tx_mode: str, 62 | ) -> None: 63 | if channels == (None,): 64 | channels = () 65 | 66 | size = size + channels 67 | 68 | tx = ptx.RandomAdjustSharpness( 69 | sharpness_factor=sharpness_factor, 70 | p=p, 71 | tx_mode=tx_mode, 72 | ) 73 | 74 | img = Image.fromarray( 75 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 76 | ) 77 | 
78 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 79 | 80 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 81 | assert len(params_1) - len(orig_params) == tx.param_count 82 | 83 | orig_params = () 84 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 85 | aug_3, params_3 = tx.consume_transform(img, params_2) 86 | 87 | assert orig_params == params_3 88 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 89 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 90 | 91 | id_params = tx.get_default_params(img=img, processed=True) 92 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 93 | assert id_aug_rem_params == () 94 | assert not ImageChops.difference(id_aug_img, img).getbbox() 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "size, sharpness_factor, p, channels, tx_mode", 99 | list( 100 | product( 101 | [ 102 | ( 103 | 224, 104 | 224, 105 | ), 106 | ], 107 | [0.1, 0.5, 1.0, 2.0], 108 | [0.0, 0.5, 1.0], 109 | [ 110 | (3,), 111 | # # These give error 112 | # (4,), 113 | # (None,), 114 | ], 115 | ["CASCADE", "CONSUME"], 116 | ) 117 | ), 118 | ) 119 | def test_tx_on_torch_tensors( 120 | size: t.Tuple[int], 121 | sharpness_factor: float, 122 | p: float, 123 | channels: t.Tuple[int], 124 | tx_mode: str, 125 | ) -> None: 126 | if channels == (None,): 127 | channels = () 128 | 129 | size = channels + size 130 | 131 | tx = ptx.RandomAdjustSharpness( 132 | sharpness_factor=sharpness_factor, 133 | p=p, 134 | tx_mode=tx_mode, 135 | ) 136 | 137 | img = torch.from_numpy(np.random.randint(low=0, high=1.0, size=size)) 138 | 139 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 140 | 141 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 142 | assert len(params_1) - len(orig_params) == tx.param_count 143 | 144 | orig_params = () 145 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 146 | aug_3, params_3 = tx.consume_transform(img, params_2) 147 | 148 | assert orig_params == params_3 149 | 
assert torch.all(torch.eq(aug_2, aug_3)) 150 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 151 | 152 | id_params = tx.get_default_params(img=img, processed=True) 153 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 154 | assert id_aug_rem_params == () 155 | assert torch.all(torch.eq(id_aug_img, img)) 156 | 157 | 158 | # Main. 159 | if __name__ == "__main__": 160 | pass 161 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomAutocontrast.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 
19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.0, 0.5, 1.0], 46 | [ 47 | (3,), 48 | # # These are not allowed 49 | # (4,), 50 | # (None,), 51 | ], 52 | ["CASCADE", "CONSUME"], 53 | ) 54 | ), 55 | ) 56 | def test_tx_on_PIL_images( 57 | size: t.Tuple[int], 58 | p: float, 59 | channels: t.Tuple[int], 60 | tx_mode: str, 61 | ) -> None: 62 | if channels == (None,): 63 | channels = () 64 | 65 | size = size + channels 66 | 67 | tx = ptx.RandomAutocontrast( 68 | p=p, 69 | tx_mode=tx_mode, 70 | ) 71 | 72 | img = Image.fromarray( 73 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 74 | ) 75 | 76 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 77 | 78 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 79 | assert len(params_1) - len(orig_params) == tx.param_count 80 | 81 | orig_params = () 82 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 83 | aug_3, params_3 = tx.consume_transform(img, params_2) 84 | 85 | assert orig_params == params_3 86 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 87 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 88 | 89 | id_params = tx.get_default_params(img=img, processed=True) 90 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 91 | assert id_aug_rem_params == () 92 | assert not ImageChops.difference(id_aug_img, img).getbbox() 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "size, p, channels, tx_mode", 97 | list( 98 | product( 99 | [ 100 | ( 101 | 224, 102 | 224, 103 | ), 104 | ], 105 | [0.0, 0.5, 1.0], 106 | [ 107 | (3,), 108 | # # These are not allowed 109 
| # (4,), 110 | # (None,), 111 | ], 112 | ["CASCADE", "CONSUME"], 113 | ) 114 | ), 115 | ) 116 | def test_tx_on_torch_tensors( 117 | size: t.Tuple[int], 118 | p: float, 119 | channels: t.Tuple[int], 120 | tx_mode: str, 121 | ) -> None: 122 | if channels == (None,): 123 | channels = () 124 | 125 | size = channels + size 126 | 127 | tx = ptx.RandomAutocontrast( 128 | p=p, 129 | tx_mode=tx_mode, 130 | ) 131 | 132 | img = torch.from_numpy(np.random.randint(low=0, high=1.0, size=size)) 133 | 134 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 135 | 136 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 137 | assert len(params_1) - len(orig_params) == tx.param_count 138 | 139 | orig_params = () 140 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 141 | aug_3, params_3 = tx.consume_transform(img, params_2) 142 | 143 | assert orig_params == params_3 144 | assert torch.all(torch.eq(aug_2, aug_3)) 145 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 146 | 147 | id_params = tx.get_default_params(img=img, processed=True) 148 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 149 | assert id_aug_rem_params == () 150 | assert torch.all(torch.eq(id_aug_img, img)) 151 | 152 | 153 | # Main. 154 | if __name__ == "__main__": 155 | pass 156 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomEqualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.0, 0.5, 1.0], 46 | [ 47 | (3,), 48 | # # This is 49 | # (4,), 50 | # (None,), 51 | ], 52 | ["CASCADE", "CONSUME"], 53 | ) 54 | ), 55 | ) 56 | def test_tx_on_PIL_images( 57 | size: t.Tuple[int], 58 | p: float, 59 | channels: t.Tuple[int], 60 | tx_mode: str, 61 | ) -> None: 62 | if channels == (None,): 63 | channels = () 64 | 65 | size = size + channels 66 | 67 | tx = ptx.RandomEqualize( 68 | p=p, 69 | tx_mode=tx_mode, 70 | ) 71 | 72 | img = Image.fromarray( 73 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 74 | ) 75 | 76 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 77 | 78 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 79 | assert len(params_1) - len(orig_params) == tx.param_count 80 | 81 | orig_params = () 82 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 83 | aug_3, params_3 = tx.consume_transform(img, params_2) 84 | 85 | assert orig_params == params_3 86 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 87 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 88 | 89 | id_params = 
tx.get_default_params(img=img, processed=True) 90 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 91 | assert id_aug_rem_params == () 92 | assert not ImageChops.difference(id_aug_img, img).getbbox() 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "size, p, channels, tx_mode", 97 | list( 98 | product( 99 | [ 100 | ( 101 | 224, 102 | 224, 103 | ), 104 | ], 105 | [0.0, 0.5, 1.0], 106 | [ 107 | (3,), 108 | # # These give error 109 | # (4,), 110 | # (None,), 111 | ], 112 | ["CASCADE", "CONSUME"], 113 | ) 114 | ), 115 | ) 116 | def test_tx_on_torch_tensors( 117 | size: t.Tuple[int], 118 | p: float, 119 | channels: t.Tuple[int], 120 | tx_mode: str, 121 | ) -> None: 122 | if channels == (None,): 123 | channels = () 124 | 125 | size = channels + size 126 | 127 | tx = ptx.RandomEqualize( 128 | p=p, 129 | tx_mode=tx_mode, 130 | ) 131 | 132 | img = torch.from_numpy( 133 | np.random.uniform(low=0, high=256, size=size).astype(np.uint8) 134 | ) 135 | 136 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 137 | 138 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 139 | assert len(params_1) - len(orig_params) == tx.param_count 140 | 141 | orig_params = () 142 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 143 | aug_3, params_3 = tx.consume_transform(img, params_2) 144 | 145 | assert orig_params == params_3 146 | assert torch.all(torch.eq(aug_2, aug_3)) 147 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 148 | 149 | id_params = tx.get_default_params(img=img, processed=True) 150 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 151 | assert id_aug_rem_params == () 152 | assert torch.all(torch.eq(id_aug_img, img)) 153 | 154 | 155 | # Main. 
156 | if __name__ == "__main__": 157 | pass 158 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomErasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | import parameterized_transforms.core as ptc 25 | 26 | import numpy as np 27 | 28 | from itertools import product 29 | 30 | import typing as t 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "size, p, scale, ratio, value, default_params_mode, channels, tx_mode", 35 | list( 36 | product( 37 | [ 38 | ( 39 | 224, 40 | 224, 41 | ), 42 | ], 43 | # p 44 | [0.0, 0.5, 1.0], 45 | # scale 46 | [ 47 | (0.02, 0.33), 48 | ], 49 | # ratio 50 | [ 51 | (3.0 / 4, 4.0 / 3), 52 | ], 53 | # value 54 | ["random", 0, 23, 255.0], 55 | # identity params mode 56 | [ 57 | ptc.DefaultParamsMode.RANDOMIZED, 58 | ptc.DefaultParamsMode.UNIQUE, 59 | ], 60 | # channels 61 | [ 62 | (3,), 63 | (4,), 64 | ], 65 | ["CASCADE", "CONSUME"], 66 | ) 67 | ), 68 | ) 69 | def test_tx_on_torch_tensors( 70 | size: t.Tuple[int], 71 | p: float, 72 | scale: t.Union[t.Tuple[float, float], t.List[float]], 73 | ratio: t.Union[t.Tuple[float, float], t.List[float]], 74 | value: t.Union[str, int, str, t.List[int], 
t.Tuple[int, int, int]], 75 | default_params_mode: ptc.DefaultParamsMode, 76 | channels: t.Tuple[int], 77 | tx_mode: str, 78 | ) -> None: 79 | 80 | # PIL-support is NOT guaranteed in torchvision 81 | 82 | if channels == (None,): 83 | channels = () 84 | 85 | size = channels + size 86 | 87 | tx = ptx.RandomErasing( 88 | p=p, 89 | scale=scale, 90 | ratio=ratio, 91 | value=value, 92 | tx_mode=tx_mode, 93 | default_params_mode=default_params_mode, 94 | ) 95 | 96 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 97 | 98 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 99 | 100 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 101 | assert len(params_1) - len(orig_params) == tx.param_count 102 | 103 | orig_params = () 104 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 105 | aug_3, params_3 = tx.consume_transform(img, params_2) 106 | 107 | assert orig_params == params_3 108 | 109 | if value != "random": 110 | assert torch.all(torch.eq(aug_2, aug_3)) 111 | 112 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 113 | 114 | id_params = tx.get_default_params(img=img, processed=True) 115 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 116 | if value != "random": 117 | assert id_aug_rem_params == () 118 | assert torch.all(torch.eq(id_aug_img, img)) 119 | 120 | 121 | # Main. 122 | if __name__ == "__main__": 123 | pass 124 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomGrayscale.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.0, 0.5, 1.0], 46 | [ 47 | (3,), 48 | # (4,), 49 | (None,), 50 | ], 51 | ["CASCADE", "CONSUME"], 52 | ) 53 | ), 54 | ) 55 | def test_tx_on_PIL_images( 56 | size: t.Tuple[int], 57 | p: float, 58 | channels: t.Tuple[int], 59 | tx_mode: str, 60 | ) -> None: 61 | 62 | if channels == (None,): 63 | channels = () 64 | 65 | size = size + channels 66 | 67 | tx = ptx.RandomGrayscale( 68 | p=p, 69 | tx_mode=tx_mode, 70 | ) 71 | 72 | img = Image.fromarray( 73 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 74 | ) 75 | 76 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 77 | 78 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 79 | assert len(params_1) - len(orig_params) == tx.param_count 80 | 81 | orig_params = () 82 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 83 | aug_3, params_3 = tx.consume_transform(img, params_2) 84 | 85 | assert orig_params == params_3 86 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 87 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 88 | 89 | id_params = 
tx.get_default_params(img=img, processed=True) 90 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 91 | assert id_aug_rem_params == () 92 | assert not ImageChops.difference(id_aug_img, img).getbbox() 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "size, p, channels, tx_mode", 97 | list( 98 | product( 99 | [ 100 | ( 101 | 224, 102 | 224, 103 | ), 104 | ], 105 | [0.0, 0.5, 1.0], 106 | [ 107 | (3,), 108 | ], 109 | ["CASCADE", "CONSUME"], 110 | ) 111 | ), 112 | ) 113 | def test_tx_on_torch_tensors( 114 | size: t.Tuple[int], 115 | p: int, 116 | channels: t.Tuple[int], 117 | tx_mode: str, 118 | ) -> None: 119 | 120 | if channels == (None,): 121 | channels = () 122 | 123 | size = channels + size 124 | 125 | tx = ptx.RandomGrayscale( 126 | p=p, 127 | tx_mode=tx_mode, 128 | ) 129 | 130 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 131 | 132 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 133 | 134 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 135 | assert len(params_1) - len(orig_params) == tx.param_count 136 | 137 | orig_params = () 138 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 139 | aug_3, params_3 = tx.consume_transform(img, params_2) 140 | 141 | assert orig_params == params_3 142 | assert torch.all(torch.eq(aug_2, aug_3)) 143 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 144 | 145 | id_params = tx.get_default_params(img=img, processed=True) 146 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 147 | assert id_aug_rem_params == () 148 | assert torch.all(torch.eq(id_aug_img, img)) 149 | 150 | 151 | # Main. 152 | if __name__ == "__main__": 153 | pass 154 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomHorizontalFlip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.0, 0.5, 1.0], 46 | [(3,), (4,), (None,)], 47 | ["CASCADE", "CONSUME"], 48 | ) 49 | ), 50 | ) 51 | def test_tx_on_PIL_images( 52 | size: t.Tuple[int], 53 | p: float, 54 | channels: t.Tuple[int], 55 | tx_mode: str, 56 | ) -> None: 57 | 58 | # 59 | class DummyContext(object): 60 | def __enter__(self): 61 | pass 62 | 63 | def __exit__(self, exc_type, exc_val, exc_tb): 64 | pass 65 | 66 | if channels == (None,): 67 | channels = () 68 | 69 | size = size + channels 70 | 71 | tx = ptx.RandomHorizontalFlip( 72 | p=p, 73 | tx_mode=tx_mode, 74 | ) 75 | 76 | img = Image.fromarray( 77 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 78 | ) 79 | 80 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 81 | 82 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 83 | assert len(params_1) - len(orig_params) == tx.param_count 84 | 85 | orig_params = () 86 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 87 | aug_3, params_3 = 
tx.consume_transform(img, params_2) 88 | 89 | assert orig_params == params_3 90 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 91 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 92 | 93 | if params_2[-1] == 0: 94 | assert not ImageChops.difference(aug_2, img).getbbox() 95 | 96 | id_params = tx.get_default_params(img=img, processed=True) 97 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 98 | assert id_aug_rem_params == () 99 | assert not ImageChops.difference(id_aug_img, img).getbbox() 100 | 101 | 102 | @pytest.mark.parametrize( 103 | "size, p, channels, tx_mode", 104 | list( 105 | product( 106 | [ 107 | ( 108 | 224, 109 | 224, 110 | ), 111 | ], 112 | [0.0, 0.5, 1.0], 113 | [(3,), (4,), (None,)], 114 | ["CASCADE", "CONSUME"], 115 | ) 116 | ), 117 | ) 118 | def test_tx_on_torch_tensors( 119 | size: t.Tuple[int], 120 | p: float, 121 | channels: t.Tuple[int], 122 | tx_mode: str, 123 | ) -> None: 124 | 125 | # 126 | class DummyContext(object): 127 | def __enter__(self): 128 | pass 129 | 130 | def __exit__(self, exc_type, exc_val, exc_tb): 131 | pass 132 | 133 | if channels == (None,): 134 | channels = () 135 | 136 | size = channels + size 137 | 138 | tx = ptx.RandomHorizontalFlip( 139 | p=p, 140 | tx_mode=tx_mode, 141 | ) 142 | 143 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 144 | 145 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 146 | 147 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 148 | assert len(params_1) - len(orig_params) == tx.param_count 149 | 150 | orig_params = () 151 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 152 | aug_3, params_3 = tx.consume_transform(img, params_2) 153 | 154 | assert orig_params == params_3 155 | assert torch.all(torch.eq(aug_2, aug_3)) 156 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 157 | 158 | if params_2[-1] == 0: 159 | assert torch.all(torch.eq(aug_2, img)) 
160 | 161 | id_params = tx.get_default_params(img=img, processed=True) 162 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 163 | assert id_aug_rem_params == () 164 | assert torch.all(torch.eq(id_aug_img, img)) 165 | 166 | 167 | # Main. 168 | if __name__ == "__main__": 169 | pass 170 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomInvert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.0, 0.5, 1.0], 46 | [ 47 | (3,), 48 | # # These dimensions lead to error. 
49 | # (4,), 50 | # (None,) 51 | ], 52 | ["CASCADE", "CONSUME"], 53 | ) 54 | ), 55 | ) 56 | def test_tx_on_PIL_images( 57 | size: t.Tuple[int], 58 | p: float, 59 | channels: t.Tuple[int], 60 | tx_mode: str, 61 | ) -> None: 62 | 63 | if channels == (None,): 64 | channels = () 65 | 66 | size = size + channels 67 | 68 | tx = ptx.RandomInvert( 69 | p=p, 70 | tx_mode=tx_mode, 71 | ) 72 | 73 | img = Image.fromarray( 74 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 75 | ) 76 | 77 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 78 | 79 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 80 | assert len(params_1) - len(orig_params) == tx.param_count 81 | 82 | orig_params = () 83 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 84 | aug_3, params_3 = tx.consume_transform(img, params_2) 85 | 86 | assert orig_params == params_3 87 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 88 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 89 | 90 | id_params = tx.get_default_params(img=img, processed=True) 91 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 92 | assert id_aug_rem_params == () 93 | assert not ImageChops.difference(id_aug_img, img).getbbox() 94 | 95 | 96 | @pytest.mark.parametrize( 97 | "size, p, channels, tx_mode", 98 | list( 99 | product( 100 | [ 101 | ( 102 | 224, 103 | 224, 104 | ), 105 | ], 106 | [0.0, 0.5, 1.0], 107 | [ 108 | (1,), 109 | (3,), 110 | # (4,), 111 | # (None, ), 112 | ], 113 | ["CASCADE", "CONSUME"], 114 | ) 115 | ), 116 | ) 117 | def test_tx_on_torch_tensors( 118 | size: t.Tuple[int], 119 | p: float, 120 | channels: t.Tuple[int], 121 | tx_mode: str, 122 | ) -> None: 123 | 124 | if channels == (None,): 125 | channels = () 126 | 127 | size = channels + size 128 | 129 | tx = ptx.RandomInvert( 130 | p=p, 131 | tx_mode=tx_mode, 132 | ) 133 | 134 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 135 | 136 | orig_params = 
tuple([3.0, 4.0, -1.0, 0.0]) 137 | 138 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 139 | assert len(params_1) - len(orig_params) == tx.param_count 140 | 141 | orig_params = () 142 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 143 | aug_3, params_3 = tx.consume_transform(img, params_2) 144 | 145 | assert orig_params == params_3 146 | assert torch.all(torch.eq(aug_2, aug_3)) 147 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 148 | 149 | id_params = tx.get_default_params(img=img, processed=True) 150 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 151 | assert id_aug_rem_params == () 152 | assert torch.all(torch.eq(id_aug_img, img)) 153 | 154 | 155 | # Main. 156 | if __name__ == "__main__": 157 | pass 158 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomPerspective.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 
19 | import pytest 20 | 21 | import torch 22 | import torchvision.transforms.functional as tv_fn 23 | 24 | import parameterized_transforms.transforms as ptx 25 | 26 | import numpy as np 27 | 28 | import PIL.Image as Image 29 | import PIL.ImageChops as ImageChops 30 | 31 | from itertools import product 32 | 33 | import typing as t 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "repetition, size, distortion_scale, p, interpolation_mode, fill, channels, tx_mode", 38 | list( 39 | product( 40 | [idx for idx in range(2)], 41 | [ 42 | ( 43 | 224, 44 | 224, 45 | ), 46 | ], 47 | [0.0, 0.5, 1.0], 48 | [0.0, 0.5, 1.0], 49 | [ 50 | tv_fn.InterpolationMode.NEAREST, 51 | tv_fn.InterpolationMode.BILINEAR, 52 | tv_fn.InterpolationMode.BICUBIC, 53 | # # The modes below do NOT work. 54 | # tv_fn.InterpolationMode.BOX, 55 | # tv_fn.InterpolationMode.HAMMING, 56 | # tv_fn.InterpolationMode.LANCZOS, 57 | ], 58 | [ 59 | 0, 60 | 243, 61 | (23, 54, 211), 62 | (213, 1, 2, 245), 63 | ], 64 | [ 65 | # (3,), # This channel does NOT work 66 | # (4,), # This channel does NOT work 67 | (None,), # Random non-reproducible error observed 68 | ], 69 | ["CASCADE", "CONSUME"], 70 | ) 71 | ), 72 | ) 73 | def test_tx_on_PIL_images( 74 | repetition: int, 75 | size: t.Tuple[int], 76 | distortion_scale: float, 77 | p: float, 78 | interpolation_mode: tv_fn.InterpolationMode, 79 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]], 80 | channels: t.Tuple[int], 81 | tx_mode: str, 82 | ) -> None: 83 | 84 | try: 85 | 86 | if isinstance(fill, float) or isinstance(fill, int): 87 | clean_fill = fill 88 | elif len(fill) == channels[0]: 89 | clean_fill = fill 90 | else: 91 | clean_fill = 0 92 | 93 | if channels == (None,): 94 | channels = () 95 | 96 | size = size + channels 97 | 98 | tx = ptx.RandomPerspective( 99 | distortion_scale=distortion_scale, 100 | interpolation=interpolation_mode, 101 | fill=clean_fill, 102 | p=p, 103 | tx_mode=tx_mode, 104 | ) 105 | 106 | img = Image.fromarray( 107 |
np.random.randint(low=0, high=256, size=size).astype(np.uint8) 108 | ) 109 | 110 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 111 | 112 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 113 | assert len(params_1) - len(orig_params) == tx.param_count 114 | 115 | orig_params = () 116 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 117 | aug_3, params_3 = tx.consume_transform(img, params_2) 118 | 119 | assert orig_params == params_3 120 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 121 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 122 | 123 | id_params = tx.get_default_params(img=img, processed=True) 124 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 125 | assert id_aug_rem_params == () 126 | assert not ImageChops.difference(id_aug_img, img).getbbox() 127 | 128 | except torch._C._LinAlgError as e: # noqa 129 | pytest.skip(f"ERROR | `torch._C._LinAlgError` raised in the test: {e}")  # known flaky rank-deficient lstsq: skip the case instead of failing it 130 | 131 | except Exception as e: 132 | raise RuntimeError(f"ERROR | Issue raised in the test: {e}") from e 133 | 134 | 135 | @pytest.mark.parametrize( 136 | "repetition, size, distortion_scale, p, interpolation_mode, fill, channels, tx_mode", 137 | list( 138 | product( 139 | [idx for idx in range(2)], 140 | [ 141 | ( 142 | 224, 143 | 224, 144 | ), 145 | ], 146 | [0.0, 0.5, 1.0], 147 | [0.0, 0.5, 1.0], 148 | [ 149 | tv_fn.InterpolationMode.NEAREST, 150 | tv_fn.InterpolationMode.BILINEAR, 151 | # # The modes below are NOT supported. 152 | # tv_fn.InterpolationMode.BICUBIC, 153 | # tv_fn.InterpolationMode.BOX, 154 | # tv_fn.InterpolationMode.HAMMING, 155 | # tv_fn.InterpolationMode.LANCZOS, 156 | ], 157 | [ 158 | 0.0, 159 | 128.0 / 255, 160 | (23.0 / 255, 54.0 / 255, 211.0 / 255), 161 | (213.0 / 255, 1.0 / 255, 2.0 / 255, 245.0 / 255), 162 | ], 163 | [ 164 | (3,), 165 | (4,), 166 | # (None,), # Leads to errors in general. Avoiding.
167 | ], 168 | ["CASCADE", "CONSUME"], 169 | ) 170 | ), 171 | ) 172 | def test_tx_on_torch_tensors( 173 | repetition: int, 174 | size: t.Tuple[int], 175 | distortion_scale: float, 176 | p: float, 177 | interpolation_mode: tv_fn.InterpolationMode, 178 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]], 179 | channels: t.Tuple[int], 180 | tx_mode: str, 181 | ) -> None: 182 | """ 183 | Observed issue-- 184 | ``` 185 | torch._C._LinAlgError: torch.linalg.lstsq: 186 | The least squares solution could not be computed because 187 | the input matrix does not have full rank (error code: 8). 188 | ``` 189 | Currently, the work-around is to skip the test case if this error pops up. 190 | """ 191 | 192 | try: 193 | 194 | if isinstance(fill, float) or isinstance(fill, int): 195 | clean_fill = fill 196 | elif len(fill) == channels[0]: 197 | clean_fill = fill 198 | else: 199 | clean_fill = 0 200 | 201 | if channels == (None,): 202 | channels = () 203 | 204 | size = channels + size 205 | 206 | tx = ptx.RandomPerspective( 207 | distortion_scale=distortion_scale, 208 | interpolation=interpolation_mode, 209 | fill=clean_fill, 210 | p=p, 211 | tx_mode=tx_mode, 212 | ) 213 | 214 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 215 | 216 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 217 | 218 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 219 | assert len(params_1) - len(orig_params) == tx.param_count 220 | 221 | orig_params = () 222 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 223 | aug_3, params_3 = tx.consume_transform(img, params_2) 224 | 225 | assert orig_params == params_3 226 | 227 | if interpolation_mode != tv_fn.InterpolationMode.BILINEAR:  # NOTE(review): BILINEAR cases skip every check below — confirm intent 228 | assert torch.all(torch.isclose(aug_2, aug_3, rtol=1e-02, atol=1e-05)) 229 | assert all( 230 | [isinstance(elt, float) or isinstance(elt, int) for elt in params_2] 231 | ) 232 | 233 | id_params = 
tx.get_default_params(img=img, processed=True) 234 | id_aug_img, id_aug_rem_params =
tx.consume_transform( 235 | img=img, params=id_params 236 | ) 237 | assert id_aug_rem_params == () 238 | assert torch.all(torch.isclose(id_aug_img, img, rtol=1e-02, atol=1e-05)) 239 | 240 | except torch._C._LinAlgError as e: # noqa 241 | pytest.skip(f"ERROR | torch._C._LinAlgError raised in the test: {e}")  # documented work-around: skip instead of failing 242 | 243 | except Exception as e: 244 | raise RuntimeError(f"ERROR | Issue raised in the test: {e}") from e 245 | 246 | 247 | # Main. 248 | if __name__ == "__main__": 249 | pass 250 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomPosterize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, bits, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [num_bits for num_bits in range(9)], 46 | [0.0, 1.0, 0.5], 47 | [ 48 | (3,), 49 | # # Not supported.
50 | # (4,), 51 | # (None,), 52 | ], 53 | ["CASCADE", "CONSUME"], 54 | ) 55 | ), 56 | ) 57 | def test_tx_on_PIL_images( 58 | size: t.Tuple[int], 59 | bits: int, 60 | p: float, 61 | channels: t.Tuple[int], 62 | tx_mode: str, 63 | ) -> None: 64 | if channels == (None,): 65 | channels = () 66 | 67 | size = size + channels  # channel dim appended: channel-last shape for the PIL image 68 | 69 | tx = ptx.RandomPosterize( 70 | bits=bits, 71 | p=p, 72 | tx_mode=tx_mode, 73 | ) 74 | 75 | img = Image.fromarray( 76 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 77 | ) 78 | 79 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 80 | 81 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 82 | assert len(params_1) - len(orig_params) == tx.param_count 83 | 84 | orig_params = () 85 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 86 | aug_3, params_3 = tx.consume_transform(img, params_2) 87 | 88 | assert orig_params == params_3 89 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 90 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 91 | 92 | id_params = tx.get_default_params(img=img, processed=True) 93 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 94 | assert id_aug_rem_params == () 95 | assert not ImageChops.difference(id_aug_img, img).getbbox() 96 | 97 | 98 | @pytest.mark.parametrize( 99 | "size, bits, p, channels, tx_mode", 100 | list( 101 | product( 102 | [ 103 | ( 104 | 224, 105 | 224, 106 | ), 107 | ], 108 | [num_bits for num_bits in range(9)], 109 | [0.0, 1.0, 0.5], 110 | [ 111 | (3,), 112 | # # Not supported.
113 | # (4,), 114 | # (None,), 115 | ], 116 | ["CASCADE", "CONSUME"], 117 | ) 118 | ), 119 | ) 120 | def test_tx_on_torch_tensors( 121 | size: t.Tuple[int], 122 | bits: int, 123 | p: float, 124 | channels: t.Tuple[int], 125 | tx_mode: str, 126 | ) -> None: 127 | if channels == (None,): 128 | channels = () 129 | 130 | size = channels + size  # channel dim prepended: channel-first shape for the tensor image 131 | 132 | tx = ptx.RandomPosterize( 133 | bits=bits, 134 | p=p, 135 | tx_mode=tx_mode, 136 | ) 137 | 138 | img = torch.from_numpy( 139 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)  # uint8 tensor image with values in [0, 255] 140 | ) 141 | 142 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 143 | 144 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 145 | assert len(params_1) - len(orig_params) == tx.param_count 146 | 147 | orig_params = () 148 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 149 | aug_3, params_3 = tx.consume_transform(img, params_2) 150 | 151 | assert orig_params == params_3 152 | assert torch.all(torch.eq(aug_2, aug_3)) 153 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 154 | 155 | id_params = tx.get_default_params(img=img, processed=True) 156 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 157 | assert id_aug_rem_params == () 158 | assert torch.all(torch.eq(id_aug_img, img)) 159 | 160 | 161 | # Main. 162 | if __name__ == "__main__": 163 | pass 164 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomResizedCrop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | import torchvision.transforms.functional as tv_fn 23 | 24 | import parameterized_transforms.transforms as ptx 25 | 26 | import numpy as np 27 | 28 | import PIL.Image as Image 29 | import PIL.ImageChops as ImageChops 30 | 31 | from itertools import product 32 | 33 | import typing as t 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "size, to_size, scale, ratio, interpolation_mode, channels, tx_mode", 38 | list( 39 | product( 40 | [ 41 | ( 42 | 224, 43 | 224, 44 | ), 45 | ], 46 | [ 47 | 224, 48 | 32, 49 | [224, 224], 50 | [32, 32], 51 | (224, 224), 52 | (32, 32), 53 | [224, 32], 54 | (24, 224), 55 | ], 56 | [ 57 | (0.33, 0.67), 58 | [0.08, 1.0], 59 | ], 60 | [ 61 | (0.08, 1.0),  # NOTE(review): these `ratio` ranges repeat the `scale` values; torchvision's default aspect-ratio range is (3/4, 4/3) — confirm intent 62 | [0.08, 1.0], 63 | ], 64 | [ 65 | tv_fn.InterpolationMode.NEAREST, 66 | tv_fn.InterpolationMode.BILINEAR, 67 | tv_fn.InterpolationMode.BICUBIC, 68 | tv_fn.InterpolationMode.BOX, 69 | tv_fn.InterpolationMode.HAMMING, 70 | tv_fn.InterpolationMode.LANCZOS, 71 | ], 72 | [(3,), (4,), (None,)], 73 | ["CASCADE", "CONSUME"], 74 | ) 75 | ), 76 | ) 77 | def test_tx_on_PIL_images( 78 | size: t.Tuple[int], 79 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]], 80 | scale: t.Union[t.List[float], t.Tuple[float, float]], 81 | ratio: t.Union[t.List[float], t.Tuple[float, float]], 82 | interpolation_mode: 
tv_fn.InterpolationMode, 83 | channels: t.Tuple[int], 84 | tx_mode: str, 85 | ) -> None: 86 | if channels == (None,): 87 | channels = () 88 | 89 | size = size + channels  # channel dim appended: channel-last shape for the PIL image 90 | 91 | tx =
ptx.RandomResizedCrop( 92 | size=to_size, 93 | scale=scale, 94 | ratio=ratio, 95 | interpolation=interpolation_mode, 96 | tx_mode=tx_mode, 97 | ) 98 | 99 | img = Image.fromarray( 100 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 101 | ) 102 | 103 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 104 | 105 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 106 | assert len(params_1) - len(orig_params) == tx.param_count 107 | 108 | orig_params = () 109 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 110 | aug_3, params_3 = tx.consume_transform(img, params_2) 111 | 112 | assert orig_params == params_3 113 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 114 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 115 | 116 | id_params = tx.get_default_params(img=img, processed=True) 117 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 118 | if img.size == id_aug_img.size:  # identity checks only make sense when the crop is resized back to the input size 119 | assert id_aug_rem_params == () 120 | assert not ImageChops.difference(id_aug_img, img).getbbox() 121 | 122 | 123 | @pytest.mark.parametrize( 124 | "size, to_size, scale, ratio, interpolation_mode, channels, tx_mode", 125 | list( 126 | product( 127 | [ 128 | ( 129 | 224, 130 | 224, 131 | ), 132 | ], 133 | [ 134 | 224, 135 | 32, 136 | [224, 224], 137 | [32, 32], 138 | (224, 224), 139 | (32, 32), 140 | [224, 32], 141 | (24, 224), 142 | ], 143 | [ 144 | (0.33, 0.67), 145 | [0.08, 1.0], 146 | ], 147 | [ 148 | (0.08, 1.0), 149 | [0.08, 1.0], 150 | ], 151 | [ 152 | tv_fn.InterpolationMode.NEAREST, 153 | tv_fn.InterpolationMode.BILINEAR, 154 | tv_fn.InterpolationMode.BICUBIC, 155 | # # These modes below are NOT supported for tensors.
156 | # tv_fn.InterpolationMode.BOX, 157 | # tv_fn.InterpolationMode.HAMMING, 158 | # tv_fn.InterpolationMode.LANCZOS, 159 | ], 160 | [(3,), (4,)], 161 | ["CASCADE", "CONSUME"], 162 | ) 163 | ), 164 | ) 165 | def test_tx_on_torch_tensors( 166 | size: t.Tuple[int], 167 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]], 168 | scale: t.Union[t.List[float], t.Tuple[float, float]], 169 | ratio: t.Union[t.List[float], t.Tuple[float, float]], 170 | interpolation_mode: tv_fn.InterpolationMode, 171 | channels: t.Tuple[int], 172 | tx_mode: str, 173 | ) -> None: 174 | if channels == (None,): 175 | channels = () 176 | 177 | size = channels + size  # channel dim prepended: channel-first shape for the tensor image 178 | 179 | tx = ptx.RandomResizedCrop( 180 | size=to_size, 181 | scale=scale, 182 | ratio=ratio, 183 | interpolation=interpolation_mode, 184 | tx_mode=tx_mode, 185 | ) 186 | 187 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 188 | 189 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 190 | 191 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 192 | assert len(params_1) - len(orig_params) == tx.param_count 193 | 194 | orig_params = () 195 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 196 | aug_3, params_3 = tx.consume_transform(img, params_2) 197 | 198 | assert orig_params == params_3 199 | assert torch.all(torch.eq(aug_2, aug_3)) 200 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 201 | 202 | id_params = tx.get_default_params(img=img, processed=True) 203 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 204 | if img.shape == id_aug_img.shape:  # identity checks only make sense when the output shape matches the input 205 | assert id_aug_rem_params == () 206 | assert torch.all(torch.eq(id_aug_img, img)) 207 | 208 | 209 | # Main.
210 | if __name__ == "__main__": 211 | pass 212 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomRotation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | import torchvision.transforms.functional as tv_fn 23 | 24 | import parameterized_transforms.transforms as ptx 25 | 26 | import numpy as np 27 | 28 | import PIL.Image as Image 29 | import PIL.ImageChops as ImageChops 30 | 31 | from itertools import product 32 | 33 | import typing as t 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "size, degrees, interpolation_mode, expand, fill, channels, tx_mode", 38 | list( 39 | product( 40 | [ 41 | ( 42 | 224, 43 | 224, 44 | ), 45 | ], 46 | [ 47 | 10, 48 | 0.0, 49 | [-10, 10.0], 50 | [3.14, 2 * 3.14], 51 | ], 52 | [ 53 | tv_fn.InterpolationMode.NEAREST, 54 | tv_fn.InterpolationMode.BILINEAR, 55 | tv_fn.InterpolationMode.BICUBIC, 56 | tv_fn.InterpolationMode.BOX, 57 | tv_fn.InterpolationMode.LANCZOS, 58 | tv_fn.InterpolationMode.HAMMING, 59 | ], 60 | [ 61 | True, 62 | False, 63 | ], 64 | [ 65 | 0, 66 | 243, 67 | (23, 54, 211), 68 | (213, 1, 2, 245), 69 | ], 70 | [(3,), (4,), (None,)], 71 | ["CASCADE", "CONSUME"], 72 | ) 73 | ), 74 | ) 75 | def test_tx_on_PIL_images( 76 | size:
t.Tuple[int], 77 | degrees: t.Union[ 78 | float, 79 | int, 80 | t.List[float], 81 | t.List[int], 82 | t.Tuple[float, float], 83 | t.Tuple[int, int], 84 | ], 85 | interpolation_mode: tv_fn.InterpolationMode, 86 | expand: bool, 87 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]], 88 | channels: t.Tuple[int], 89 | tx_mode: str, 90 | ) -> None: 91 | 92 | if isinstance(fill, float) or isinstance(fill, int): 93 | clean_fill = fill 94 | elif len(fill) == channels[0]: 95 | clean_fill = fill 96 | else: 97 | clean_fill = 0  # a per-channel fill whose length differs from the channel count falls back to 0 98 | 99 | if channels == (None,): 100 | channels = () 101 | 102 | size = size + channels  # channel dim appended: channel-last shape for the PIL image 103 | 104 | tx = ptx.RandomRotation( 105 | degrees=degrees, 106 | interpolation=interpolation_mode, 107 | expand=expand, 108 | center=None, 109 | fill=clean_fill, 110 | tx_mode=tx_mode, 111 | ) 112 | 113 | img = Image.fromarray( 114 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 115 | ) 116 | 117 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 118 | 119 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 120 | assert len(params_1) - len(orig_params) == tx.param_count 121 | 122 | orig_params = () 123 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 124 | aug_3, params_3 = tx.consume_transform(img, params_2) 125 | 126 | assert orig_params == params_3 127 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 128 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 129 | 130 | id_params = tx.get_default_params(img=img, processed=True) 131 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 132 | assert id_aug_rem_params == () 133 | assert not ImageChops.difference(id_aug_img, img).getbbox() 134 | 135 | 136 | @pytest.mark.parametrize( 137 | "size, degrees, interpolation_mode, expand, fill, channels, tx_mode", 138 | list( 139 | product( 140 | [ 141 | ( 142 | 224, 143 | 224, 144 | ), 145 | ], 146 | [ 147 | 10, 148 | 0.0, 149 | [-10, 10.0], 150 | [3.14, 2 * 3.14],
151 | # -1000., # This raises error; if single number, it MUST be positive 152 | ], 153 | [ 154 | tv_fn.InterpolationMode.NEAREST, 155 | tv_fn.InterpolationMode.BILINEAR, 156 | tv_fn.InterpolationMode.BICUBIC, 157 | tv_fn.InterpolationMode.BOX, 158 | tv_fn.InterpolationMode.LANCZOS, 159 | tv_fn.InterpolationMode.HAMMING, 160 | ], 161 | [ 162 | True, 163 | False, 164 | ], 165 | [ 166 | 0.0, 167 | 128.0 / 255, 168 | (23.0 / 255, 54.0 / 255, 211.0 / 255), 169 | (213.0 / 255, 1.0 / 255, 2.0 / 255, 245.0 / 255), 170 | ], 171 | [ 172 | (3,), 173 | (4,), 174 | ], 175 | ["CASCADE", "CONSUME"], 176 | ) 177 | ), 178 | ) 179 | def test_tx_on_torch_tensors( 180 | size: t.Tuple[int], 181 | degrees: t.Union[ 182 | float, 183 | int, 184 | t.List[float], 185 | t.List[int], 186 | t.Tuple[float, float], 187 | t.Tuple[int, int], 188 | ], 189 | interpolation_mode: tv_fn.InterpolationMode, 190 | expand: bool, 191 | fill: t.Union[int, float, t.Sequence[int], t.Sequence[float]], 192 | channels: t.Tuple[int], 193 | tx_mode: str, 194 | ) -> None: 195 | 196 | if isinstance(fill, float) or isinstance(fill, int): 197 | clean_fill = fill 198 | elif len(fill) == channels[0]: 199 | clean_fill = fill 200 | else: 201 | clean_fill = 0  # a per-channel fill whose length differs from the channel count falls back to 0 202 | 203 | if channels == (None,): 204 | channels = () 205 | 206 | size = channels + size  # channel dim prepended: channel-first shape for the tensor image 207 | 208 | tx = ptx.RandomRotation( 209 | degrees=degrees, 210 | interpolation=interpolation_mode, 211 | expand=expand, 212 | center=None, 213 | fill=clean_fill, 214 | tx_mode=tx_mode, 215 | ) 216 | 217 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 218 | 219 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 220 | 221 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 222 | assert len(params_1) - len(orig_params) == tx.param_count 223 | 224 | orig_params = () 225 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 226 | aug_3, params_3 = tx.consume_transform(img, params_2) 227 | 228 | assert orig_params == params_3 229 | assert
torch.all(torch.eq(aug_2, aug_3)) 230 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 231 | 232 | id_params = tx.get_default_params(img=img, processed=True) 233 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 234 | assert id_aug_rem_params == () 235 | assert torch.all(torch.eq(id_aug_img, img)) 236 | 237 | 238 | # Main. 239 | if __name__ == "__main__": 240 | pass 241 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomSolarize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, threshold, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [num_bits for num_bits in range(9)],  # used as `threshold` values 0..8 despite the variable name 46 | [0, 0.5, 1.0], 47 | [ 48 | (3,), 49 | # # Not supported.
50 | # (4,), 51 | # (None,), 52 | ], 53 | ["CASCADE", "CONSUME"], 54 | ) 55 | ), 56 | ) 57 | def test_tx_on_PIL_images( 58 | size: t.Tuple[int], 59 | threshold: int, 60 | p: float, 61 | channels: t.Tuple[int], 62 | tx_mode: str, 63 | ) -> None: 64 | if channels == (None,): 65 | channels = () 66 | 67 | size = size + channels  # channel dim appended: channel-last shape for the PIL image 68 | 69 | tx = ptx.RandomSolarize( 70 | threshold=threshold, 71 | p=p, 72 | tx_mode=tx_mode, 73 | ) 74 | 75 | img = Image.fromarray( 76 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 77 | ) 78 | 79 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 80 | 81 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 82 | assert len(params_1) - len(orig_params) == tx.param_count 83 | 84 | orig_params = () 85 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 86 | aug_3, params_3 = tx.consume_transform(img, params_2) 87 | 88 | assert orig_params == params_3 89 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 90 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 91 | 92 | id_params = tx.get_default_params(img=img, processed=True) 93 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 94 | assert id_aug_rem_params == () 95 | assert not ImageChops.difference(id_aug_img, img).getbbox() 96 | 97 | 98 | @pytest.mark.parametrize( 99 | "size, threshold, p, channels, tx_mode", 100 | list( 101 | product( 102 | [ 103 | ( 104 | 224, 105 | 224, 106 | ), 107 | ( 108 | 32, 109 | 32, 110 | ), 111 | ( 112 | 28, 113 | 28, 114 | ), 115 | ], 116 | [num_bits for num_bits in range(9)],  # used as `threshold` values 0..8 despite the variable name 117 | # Even weird values work! 118 | [0, 0.5, 1.0, 255,],  # these are `p` values; 255 lies outside [0, 1] 119 | [ 120 | (3,), 121 | # # Not supported.
122 | # (4,), 123 | # (None,), 124 | ], 125 | ["CASCADE", "CONSUME"], 126 | ) 127 | ), 128 | ) 129 | def test_tx_on_torch_tensors( 130 | size: t.Tuple[int], 131 | threshold: float, 132 | p: float, 133 | channels: t.Tuple[int], 134 | tx_mode: str, 135 | ) -> None: 136 | if channels == (None,): 137 | channels = () 138 | 139 | size = channels + size  # channel dim prepended: channel-first shape for the tensor image 140 | 141 | tx = ptx.RandomSolarize( 142 | threshold=threshold, 143 | p=p, 144 | tx_mode=tx_mode, 145 | ) 146 | 147 | img = torch.from_numpy( 148 | np.random.randint(low=0, high=256, size=size).astype(np.uint8)  # uint8 tensor image with values in [0, 255] 149 | ) 150 | 151 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 152 | 153 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 154 | assert len(params_1) - len(orig_params) == tx.param_count 155 | 156 | orig_params = () 157 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 158 | aug_3, params_3 = tx.consume_transform(img, params_2) 159 | 160 | assert orig_params == params_3 161 | assert torch.all(torch.eq(aug_2, aug_3)) 162 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 163 | 164 | id_params = tx.get_default_params(img=img, processed=True) 165 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 166 | assert id_aug_rem_params == () 167 | assert torch.all(torch.eq(id_aug_img, img)) 168 | 169 | 170 | # Main. 171 | if __name__ == "__main__": 172 | pass 173 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_RandomVerticalFlip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [0.0, 0.5, 1.0], 46 | [(3,), (4,), (None,)], 47 | ["CASCADE", "CONSUME"], 48 | ) 49 | ), 50 | ) 51 | def test_tx_on_PIL_images( 52 | size: t.Tuple[int], 53 | p: float, 54 | channels: t.Tuple[int], 55 | tx_mode: str, 56 | ) -> None: 57 | 58 | # NOTE(review): DummyContext below is never used in this test; candidate for removal. 59 | class DummyContext(object): 60 | def __enter__(self): 61 | pass 62 | 63 | def __exit__(self, exc_type, exc_val, exc_tb): 64 | pass 65 | 66 | if channels == (None,): 67 | channels = () 68 | 69 | size = size + channels  # channel dim appended: channel-last shape for the PIL image 70 | 71 | tx = ptx.RandomVerticalFlip( 72 | p=p, 73 | tx_mode=tx_mode, 74 | ) 75 | 76 | img = Image.fromarray( 77 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 78 | ) 79 | 80 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 81 | 82 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 83 | assert len(params_1) - len(orig_params) == tx.param_count 84 | 85 | orig_params = () 86 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 87 | aug_3, params_3 = tx.consume_transform(img, params_2) 88 | 89 | assert orig_params == params_3 90 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 91 |
assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 92 | 93 | if params_2[-1] == 0:  # presumably 0 encodes "no flip"; the output must then equal the input 94 | assert not ImageChops.difference(aug_2, img).getbbox() 95 | 96 | id_params = tx.get_default_params(img=img, processed=True) 97 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 98 | assert id_aug_rem_params == () 99 | assert not ImageChops.difference(id_aug_img, img).getbbox() 100 | 101 | 102 | @pytest.mark.parametrize( 103 | "size, p, channels, tx_mode", 104 | list( 105 | product( 106 | [ 107 | ( 108 | 224, 109 | 224, 110 | ), 111 | ], 112 | [0.0, 0.5, 1.0], 113 | [(3,), (4,), (None,)], 114 | ["CASCADE", "CONSUME"], 115 | ) 116 | ), 117 | ) 118 | def test_tx_on_torch_tensors( 119 | size: t.Tuple[int], 120 | p: float, 121 | channels: t.Tuple[int], 122 | tx_mode: str, 123 | ) -> None: 124 | 125 | # NOTE(review): DummyContext below is never used in this test; candidate for removal. 126 | class DummyContext(object): 127 | def __enter__(self): 128 | pass 129 | 130 | def __exit__(self, exc_type, exc_val, exc_tb): 131 | pass 132 | 133 | if channels == (None,): 134 | channels = () 135 | 136 | size = channels + size  # channel dim prepended: channel-first shape for the tensor image 137 | 138 | tx = ptx.RandomVerticalFlip( 139 | p=p, 140 | tx_mode=tx_mode, 141 | ) 142 | 143 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 144 | 145 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 146 | 147 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 148 | assert len(params_1) - len(orig_params) == tx.param_count 149 | 150 | orig_params = () 151 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 152 | aug_3, params_3 = tx.consume_transform(img, params_2) 153 | 154 | assert orig_params == params_3 155 | assert torch.all(torch.eq(aug_2, aug_3)) 156 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 157 | 158 | if params_2[-1] == 0:  # presumably 0 encodes "no flip"; the output must then equal the input 159 | assert torch.all(torch.eq(aug_2, img)) 160 | 161 | id_params = tx.get_default_params(img=img, processed=True) 162 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img,
params=id_params) 163 | assert id_aug_rem_params == () 164 | assert torch.all(torch.eq(id_aug_img, img)) 165 | 166 | 167 | # Main. 168 | if __name__ == "__main__": 169 | pass 170 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_Resize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies.
19 | import pytest 20 | 21 | import torch 22 | import torchvision.transforms.functional as tv_fn 23 | 24 | import parameterized_transforms.transforms as ptx 25 | 26 | import numpy as np 27 | 28 | import PIL.Image as Image 29 | import PIL.ImageChops as ImageChops 30 | 31 | from itertools import product 32 | 33 | import typing as t 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "size, to_size, interpolation_mode, channels, tx_mode", 38 | list( 39 | product( 40 | [ 41 | ( 42 | 224, 43 | 224, 44 | ), 45 | ], 46 | [ 47 | [224, 224], 48 | [32, 32], 49 | (224, 224), 50 | (32, 32), 51 | (224, 32), 52 | [32, 28], 53 | (224,), 54 | (32,), 55 | [ 56 | 224, 57 | ], 58 | [ 59 | 32, 60 | ], 61 | 224, 62 | 32, 63 | ], 64 | [ 65 | tv_fn.InterpolationMode.NEAREST, 66 | tv_fn.InterpolationMode.BILINEAR, 67 | tv_fn.InterpolationMode.BICUBIC, 68 | tv_fn.InterpolationMode.BOX, 69 | tv_fn.InterpolationMode.LANCZOS, 70 | tv_fn.InterpolationMode.HAMMING, 71 | ], 72 | [ 73 | (None,), 74 | (3,), 75 | (4,), 76 | ], 77 | ["CASCADE", "CONSUME"], 78 | ) 79 | ), 80 | ) 81 | def test_tx_on_PIL_images( 82 | size: t.Tuple[int], 83 | to_size: t.Union[t.Tuple[int], t.List[int], int], 84 | interpolation_mode: tv_fn.InterpolationMode, 85 | channels: t.Tuple[int], 86 | tx_mode: str, 87 | ) -> None: 88 | 89 | if channels == (None,): 90 | channels = () 91 | 92 | size = size + channels 93 | 94 | tx = ptx.Resize( 95 | size=to_size, 96 | interpolation=interpolation_mode, 97 | max_size=None, 98 | antialias=None, 99 | tx_mode=tx_mode, 100 | ) 101 | 102 | img = Image.fromarray( 103 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 104 | ) 105 | 106 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 107 | 108 | aug_1, params_1 = tx(img, orig_params) 109 | 110 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 111 | 112 | aug_3, params_3 = tx.consume_transform(img, orig_params) 113 | 114 | assert orig_params == params_1 115 | assert orig_params == params_2 116 | assert orig_params == params_3 117 
| 118 | assert not ImageChops.difference(aug_1, aug_2).getbbox() 119 | assert not ImageChops.difference(aug_1, aug_3).getbbox() 120 | 121 | id_params = tx.get_default_params(img=img, processed=True) 122 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 123 | if img.size == id_aug_img.size: 124 | assert id_aug_rem_params == () 125 | assert not ImageChops.difference(id_aug_img, img).getbbox() 126 | 127 | 128 | @pytest.mark.parametrize( 129 | "size, to_size, interpolation_mode, channels, tx_mode", 130 | list( 131 | product( 132 | [ 133 | ( 134 | 224, 135 | 224, 136 | ), 137 | ], 138 | [ 139 | [224, 224], 140 | [32, 32], 141 | (224, 224), 142 | (32, 32), 143 | (224, 32), 144 | [32, 28], 145 | (224,), 146 | (32,), 147 | [ 148 | 224, 149 | ], 150 | [ 151 | 32, 152 | ], 153 | 224, 154 | 32, 155 | ], 156 | [ 157 | tv_fn.InterpolationMode.NEAREST, 158 | tv_fn.InterpolationMode.BILINEAR, 159 | tv_fn.InterpolationMode.BICUBIC, 160 | # # The below interpolation modes do not work well. 
161 | # tv_fn.InterpolationMode.BOX, 162 | # tv_fn.InterpolationMode.LANCZOS, 163 | # tv_fn.InterpolationMode.HAMMING, 164 | ], 165 | [ 166 | (3,), 167 | (4,), 168 | # (None, ) # This raises errors 169 | ], 170 | ["CASCADE", "CONSUME"], 171 | ) 172 | ), 173 | ) 174 | def test_tx_on_torch_tensors( 175 | size: t.Tuple[int], 176 | to_size: t.Union[t.Tuple[int], t.List[int], int], 177 | interpolation_mode: tv_fn.InterpolationMode, 178 | channels: t.Tuple[int], 179 | tx_mode: str, 180 | ) -> None: 181 | class DummyContext(object): 182 | def __enter__(self): 183 | pass 184 | 185 | def __exit__(self, exc_type, exc_val, exc_tb): 186 | pass 187 | 188 | if channels == (None,): 189 | channels = () 190 | 191 | size = channels + size 192 | 193 | tx = ptx.Resize( 194 | size=to_size, 195 | interpolation=interpolation_mode, 196 | max_size=None, 197 | antialias=None, 198 | tx_mode=tx_mode, 199 | ) 200 | 201 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 202 | 203 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 204 | 205 | if interpolation_mode in [ 206 | tv_fn.InterpolationMode.BOX, 207 | tv_fn.InterpolationMode.LANCZOS, 208 | tv_fn.InterpolationMode.HAMMING, 209 | ]: 210 | context = pytest.raises(ValueError) 211 | else: 212 | context = DummyContext() 213 | 214 | with context: 215 | 216 | aug_1, params_1 = tx(img, orig_params) 217 | 218 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 219 | 220 | aug_3, params_3 = tx.consume_transform(img, orig_params) 221 | 222 | assert orig_params == params_1 223 | assert orig_params == params_2 224 | assert orig_params == params_3 225 | 226 | assert torch.all(torch.eq(aug_1, aug_2)) 227 | assert torch.all(torch.eq(aug_1, aug_3)) 228 | 229 | id_params = tx.get_default_params(img=img, processed=True) 230 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 231 | if img.shape == id_aug_img.shape: 232 | assert id_aug_rem_params == () 233 | assert torch.all(torch.eq(id_aug_img, img)) 234 | 
235 | 236 | # Main. 237 | if __name__ == "__main__": 238 | pass 239 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_TenCrop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, to_size, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | [ 46 | # 224, # This gives error as crop size > image size is possible 47 | 28, 48 | [24, 24], 49 | (24, 24), 50 | [24, 3], 51 | [3, 8], 52 | (8, 24), 53 | ], 54 | [(3,), (4,), (None,)], 55 | ["CASCADE", "CONSUME"], 56 | ) 57 | ), 58 | ) 59 | def test_tx_on_PIL_images( 60 | size: t.Tuple[int], 61 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]], 62 | channels: t.Tuple[int], 63 | tx_mode: str, 64 | ) -> None: 65 | if channels == (None,): 66 | channels = () 67 | 68 | size = size + channels 69 | 70 | tx = ptx.TenCrop( 71 | size=to_size, 72 | tx_mode=tx_mode, 73 | ) 74 | 75 | img = Image.fromarray( 76 | 
np.random.randint(low=0, high=256, size=size).astype(np.uint8) 77 | ) 78 | 79 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 80 | 81 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 82 | assert len(params_1) - len(orig_params) == tx.param_count 83 | 84 | orig_params = () 85 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 86 | aug_3, params_3 = tx.consume_transform(img, params_2) 87 | 88 | assert orig_params == params_3 89 | assert all( 90 | [ 91 | not ImageChops.difference(aug_2_component, aug_3_component).getbbox() 92 | for aug_2_component, aug_3_component in zip(aug_2, aug_3) 93 | ] 94 | ) 95 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 96 | 97 | id_params = tx.get_default_params(img=img, processed=True) 98 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 99 | if img.size == id_aug_img[0].size: 100 | assert all( 101 | not ImageChops.difference(img, an_id_aug_img).getbbox() 102 | for an_id_aug_img in id_aug_img[:5] 103 | ) 104 | assert id_aug_rem_params == () 105 | 106 | 107 | @pytest.mark.parametrize( 108 | "size, to_size, channels, tx_mode", 109 | list( 110 | product( 111 | [ 112 | ( 113 | 224, 114 | 224, 115 | ), 116 | ], 117 | [ 118 | # 224, # This gives error as crop size > image size is possible 119 | 28, 120 | [24, 24], 121 | (24, 24), 122 | [24, 3], 123 | [3, 8], 124 | (8, 24), 125 | ], 126 | [(3,), (4,), (None,)], 127 | ["CASCADE", "CONSUME"], 128 | ) 129 | ), 130 | ) 131 | def test_tx_on_torch_tensors( 132 | size: t.Tuple[int], 133 | to_size: t.Union[int, t.List[int], t.Tuple[int, int]], 134 | channels: t.Tuple[int], 135 | tx_mode: str, 136 | ) -> None: 137 | if channels == (None,): 138 | channels = () 139 | 140 | size = channels + size 141 | 142 | tx = ptx.TenCrop( 143 | size=to_size, 144 | tx_mode=tx_mode, 145 | ) 146 | 147 | img = torch.from_numpy(np.random.uniform(low=0, high=1.0, size=size)) 148 | 149 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 150 | 151 | 
aug_1, params_1 = tx.cascade_transform(img, orig_params) 152 | assert len(params_1) - len(orig_params) == tx.param_count 153 | 154 | orig_params = () 155 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 156 | aug_3, params_3 = tx.consume_transform(img, params_2) 157 | 158 | assert orig_params == params_3 159 | assert all( 160 | [ 161 | torch.all(torch.eq(aug_2_component, aug_3_component)) 162 | for aug_2_component, aug_3_component in zip(aug_2, aug_3) 163 | ] 164 | ) 165 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 166 | 167 | id_params = tx.get_default_params(img=img, processed=True) 168 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 169 | if img.shape == id_aug_img[0].shape: 170 | assert all( 171 | torch.all(torch.eq(img, an_id_aug_img)) for an_id_aug_img in id_aug_img[:5] 172 | ) 173 | assert id_aug_rem_params == () 174 | 175 | 176 | # Main. 177 | if __name__ == "__main__": 178 | pass 179 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_ToPILImage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | # Dependencies. 
18 | import pytest 19 | import torch 20 | 21 | import parameterized_transforms.transforms as ptx 22 | 23 | import numpy as np 24 | 25 | import PIL.Image as Image 26 | import PIL.ImageChops as ImageChops 27 | 28 | from itertools import product 29 | 30 | import typing as t 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "size, channels, tx_mode", 35 | list( 36 | product( 37 | [ 38 | ( 39 | 224, 40 | 224, 41 | ), 42 | ( 43 | 32, 44 | 32, 45 | ), 46 | ( 47 | 28, 48 | 28, 49 | ), 50 | ], 51 | [ 52 | (None,), 53 | (1,), 54 | (3,), 55 | (4,), 56 | ], 57 | ["CASCADE", "CONSUME"], 58 | ) 59 | ), 60 | ) 61 | def test_tx_on_PIL_images( 62 | size: t.Tuple[int], 63 | channels: t.Tuple[int], 64 | tx_mode: str, 65 | ) -> None: 66 | with pytest.raises(TypeError): 67 | if channels == (None,): 68 | channels = () 69 | 70 | size = size + channels 71 | 72 | tx = ptx.ToPILImage() 73 | 74 | img = Image.fromarray( 75 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 76 | ) 77 | 78 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 79 | 80 | aug_1, params_1 = tx(img, orig_params) 81 | 82 | 83 | @pytest.mark.parametrize( 84 | "size, channels, channel_first, high, high_type, tx_mode", 85 | list( 86 | product( 87 | [ 88 | ( 89 | 224, 90 | 224, 91 | ), 92 | ], 93 | [ 94 | (None,), 95 | (1,), 96 | (3,), 97 | (4,), 98 | ], 99 | [False], 100 | [1.0, 256], 101 | [ 102 | np.uint8, 103 | # np.int, np.int8, np.int16, np.int32, np.int64 # This does not work 104 | ], 105 | ["CASCADE", "CONSUME"], 106 | ) 107 | ), 108 | ) 109 | def test_tx_on_ndarrays( 110 | size: t.Tuple[int], 111 | channels: t.Tuple[int], 112 | channel_first: bool, 113 | high: t.Union, 114 | high_type: t.Any, 115 | tx_mode: str, 116 | ) -> None: 117 | if channels == (None,): 118 | channels = () 119 | 120 | if channel_first: 121 | size = channels + size 122 | else: 123 | size = size + channels 124 | 125 | tx = ptx.ToPILImage() 126 | 127 | if high == 1.0: 128 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 
129 | else: 130 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8) 131 | 132 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 133 | 134 | aug_1, params_1 = tx(img, orig_params) 135 | 136 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 137 | 138 | aug_3, params_3 = tx.consume_transform(img, orig_params) 139 | 140 | assert orig_params == params_1 141 | assert orig_params == params_2 142 | assert orig_params == params_3 143 | 144 | assert not ImageChops.difference(aug_1, aug_2).getbbox() 145 | assert not ImageChops.difference(aug_1, aug_3).getbbox() 146 | 147 | 148 | @pytest.mark.parametrize( 149 | "size, channels, channel_first, high, high_type, tx_mode", 150 | list( 151 | product( 152 | [ 153 | ( 154 | 224, 155 | 224, 156 | ), 157 | ], 158 | [ 159 | (None,), 160 | (1,), 161 | (3,), 162 | (4,), 163 | ], 164 | [True], 165 | [1.0, 256], 166 | [ 167 | np.uint8, 168 | # np.int, np.int8, np.int16, np.int32, np.int64 # This does not work 169 | ], 170 | ["CASCADE", "CONSUME"], 171 | ) 172 | ), 173 | ) 174 | def test_tx_on_torch_tensors( 175 | size: t.Tuple[int], 176 | channels: t.Tuple[int], 177 | channel_first: bool, 178 | high: t.Union, 179 | high_type: t.Any, 180 | tx_mode: str, 181 | ) -> None: 182 | if channels == (None,): 183 | channels = () 184 | 185 | if channel_first: 186 | size = channels + size 187 | else: 188 | size = size + channels 189 | 190 | tx = ptx.ToPILImage() 191 | 192 | if high == 1.0: 193 | img = torch.from_numpy( 194 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 195 | ) 196 | else: 197 | img = torch.from_numpy( 198 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 199 | ) 200 | 201 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 202 | 203 | aug_1, params_1 = tx(img, orig_params) 204 | 205 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 206 | 207 | aug_3, params_3 = tx.consume_transform(img, orig_params) 208 | 209 | assert orig_params == params_1 210 | assert orig_params == params_2 
211 | assert orig_params == params_3 212 | 213 | assert not ImageChops.difference(aug_1, aug_2).getbbox() 214 | assert not ImageChops.difference(aug_1, aug_3).getbbox() 215 | 216 | 217 | # Main. 218 | if __name__ == "__main__": 219 | pass 220 | -------------------------------------------------------------------------------- /tests/test_atomic_transforms/test_ToTensor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | # Dependencies. 
18 | import pytest 19 | import torch 20 | 21 | import parameterized_transforms.transforms as ptx 22 | 23 | import numpy as np 24 | 25 | import PIL.Image as Image 26 | 27 | from itertools import product 28 | 29 | import typing as t 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "size, channels, tx_mode", 34 | list( 35 | product( 36 | [ 37 | ( 38 | 224, 39 | 224, 40 | ), 41 | ], 42 | [ 43 | (None,), 44 | (3,), 45 | (4,), 46 | ], 47 | ["CASCADE", "CONSUME"], 48 | ) 49 | ), 50 | ) 51 | def test_tx_on_PIL_images_1( 52 | size: t.Tuple[int], 53 | channels: t.Tuple[int], 54 | tx_mode: str, 55 | ) -> None: 56 | if channels == (None,): 57 | channels = () 58 | 59 | size = size + channels 60 | 61 | tx = ptx.ToTensor() 62 | 63 | img = Image.fromarray( 64 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 65 | ) 66 | 67 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 68 | 69 | aug_1, params_1 = tx(img, orig_params) 70 | 71 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 72 | 73 | aug_3, params_3 = tx.consume_transform(img, orig_params) 74 | 75 | assert len(list(aug_1.shape)) == 3 # [C, H, W] format ensured 76 | assert aug_1.shape[0] == 1 or aug_1.shape[0] == 3 or aug_1.shape[0] == 4 77 | assert 0 <= torch.min(aug_1).item() <= 1.0 78 | assert 0 <= torch.max(aug_1).item() <= 1.0 79 | 80 | assert orig_params == params_1 81 | assert orig_params == params_2 82 | assert orig_params == params_3 83 | 84 | assert torch.all(torch.eq(aug_1, aug_2)) 85 | assert torch.all(torch.eq(aug_1, aug_3)) 86 | 87 | 88 | @pytest.mark.parametrize( 89 | "size, channels, tx_mode", 90 | list( 91 | product( 92 | [ 93 | ( 94 | 224, 95 | 224, 96 | ), 97 | ], 98 | [ 99 | (1,), 100 | ], 101 | ["CASCADE", "CONSUME"], 102 | ) 103 | ), 104 | ) 105 | def test_tx_on_PIL_images_2( 106 | size: t.Tuple[int], 107 | channels: t.Tuple[int], 108 | tx_mode: str, 109 | ) -> None: 110 | with pytest.raises(TypeError): 111 | if channels == (None,): 112 | channels = () 113 | 114 | size = size + channels 115 
| 116 | tx = ptx.ToTensor() 117 | 118 | img = Image.fromarray( 119 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 120 | ) 121 | 122 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 123 | 124 | aug_1, params_1 = tx(img, orig_params) 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "size, channels, channel_first, high, high_type, tx_mode", 129 | list( 130 | product( 131 | [ 132 | ( 133 | 224, 134 | 224, 135 | ), 136 | ], 137 | [ 138 | (None,), 139 | (1,), 140 | (3,), 141 | (4,), 142 | ], 143 | [True, False], 144 | [1.0, 256], 145 | [ 146 | np.float16, 147 | np.float32, 148 | np.float64, 149 | np.uint8, 150 | np.int8, 151 | np.int16, 152 | np.int32, 153 | np.int64, 154 | ], 155 | ["CASCADE", "CONSUME"], 156 | ) 157 | ), 158 | ) 159 | def test_tx_on_ndarrays( 160 | size: t.Tuple[int], 161 | channels: t.Tuple[int], 162 | channel_first: bool, 163 | high: t.Union, 164 | high_type: t.Any, 165 | tx_mode: str, 166 | ) -> None: 167 | if channels == (None,): 168 | channels = () 169 | 170 | if channel_first: 171 | size = channels + size 172 | else: 173 | size = size + channels 174 | 175 | tx = ptx.ToTensor() 176 | 177 | if high == 1.0: 178 | img = np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 179 | else: 180 | img = np.random.randint(low=0, high=256, size=size).astype(np.uint8) 181 | 182 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 183 | 184 | aug_1, params_1 = tx(img, orig_params) 185 | 186 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 187 | 188 | aug_3, params_3 = tx.consume_transform(img, orig_params) 189 | 190 | assert len(list(aug_1.shape)) == 3 # [C, H, W] format ensured 191 | assert 0 <= torch.min(aug_1).item() <= high 192 | assert 0 <= torch.max(aug_1).item() <= high 193 | 194 | assert orig_params == params_1 195 | assert orig_params == params_2 196 | assert orig_params == params_3 197 | 198 | assert torch.all(torch.eq(aug_1, aug_2)) 199 | assert torch.all(torch.eq(aug_1, aug_3)) 200 | 201 | 202 | @pytest.mark.parametrize( 
203 | "size, channels, channel_first, high, high_type, tx_mode", 204 | list( 205 | product( 206 | [ 207 | ( 208 | 224, 209 | 224, 210 | ), 211 | ], 212 | [ 213 | (None,), 214 | (1,), 215 | (3,), 216 | (4,), 217 | ], 218 | [True, False], 219 | [1.0, 256], 220 | [ 221 | np.float16, 222 | np.float32, 223 | np.float64, 224 | np.uint8, 225 | np.int8, 226 | np.int16, 227 | np.int32, 228 | np.int64, 229 | ], 230 | ["CASCADE", "CONSUME"], 231 | ) 232 | ), 233 | ) 234 | def test_tx_on_torch_tensors( 235 | size: t.Tuple[int], 236 | channels: t.Tuple[int], 237 | channel_first: bool, 238 | high: t.Union, 239 | high_type: t.Any, 240 | tx_mode: str, 241 | ) -> None: 242 | with pytest.raises(TypeError): 243 | 244 | if channels == (None,): 245 | channels = () 246 | 247 | if channel_first: 248 | size = channels + size 249 | else: 250 | size = size + channels 251 | 252 | tx = ptx.ToTensor() 253 | 254 | if high == 1.0: 255 | img = torch.from_numpy( 256 | np.random.uniform(low=0, high=1.0, size=size).astype(high_type) 257 | ) 258 | if high == 256: 259 | img = torch.from_numpy( 260 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 261 | ) 262 | 263 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 264 | 265 | aug_1, params_1 = tx(img, orig_params) 266 | 267 | 268 | # Main. 269 | if __name__ == "__main__": 270 | pass 271 | -------------------------------------------------------------------------------- /tests/test_composing_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | 20 | 21 | # Main. 22 | if __name__ == "__main__": 23 | pass 24 | -------------------------------------------------------------------------------- /tests/test_composing_transforms/test_RandomChoice.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 
19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, p, repeat_count, core_txs, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | # p 46 | [None, "random", "equal"], 47 | # repeat count to try out a large number of orderings 48 | [idx for idx in range(5)], 49 | [ 50 | [ 51 | ptx.RandomHorizontalFlip(p=0.5), 52 | ptx.ColorJitter( 53 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 54 | ), 55 | ptx.RandomGrayscale(p=0.5), 56 | ], 57 | [ 58 | ptx.RandomHorizontalFlip(p=0.5), 59 | ptx.ColorJitter( 60 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2 61 | ), 62 | ptx.RandomGrayscale(p=0.5), 63 | ], 64 | ], 65 | [ 66 | (3,), 67 | # (4,), 68 | (None,), 69 | ], 70 | ["CASCADE", "CONSUME"], 71 | ) 72 | ), 73 | ) 74 | def test_tx_on_PIL_images( 75 | size: t.Tuple[int], 76 | p: t.Any, 77 | repeat_count: int, 78 | core_txs: t.List[t.Callable], 79 | channels: t.Tuple[int], 80 | tx_mode: str, 81 | ) -> None: 82 | 83 | # Clean probabilities. 84 | if p == "random": 85 | p = np.random.uniform( 86 | low=0, 87 | high=1.0, 88 | size=[ 89 | len(core_txs), 90 | ], 91 | ) 92 | p /= np.sum(p) 93 | elif p == "equal": 94 | p = np.ones( 95 | shape=[ 96 | len(core_txs), 97 | ] 98 | ) 99 | p /= np.sum(p) 100 | elif p is None: 101 | pass 102 | else: 103 | raise NotImplementedError( 104 | "ERROR | Ensure `p` is `None` or `equal` or `random`." 
105 | ) 106 | 107 | if channels == (None,): 108 | channels = () 109 | 110 | size = size + channels 111 | 112 | tx = ptx.RandomChoice(transforms=core_txs, p=p, tx_mode=tx_mode) 113 | 114 | img = Image.fromarray( 115 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 116 | ) 117 | 118 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 119 | 120 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 121 | assert len(params_1) - len(orig_params) == tx.param_count 122 | 123 | orig_params = () 124 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 125 | aug_3, params_3 = tx.consume_transform(img, params_2) 126 | 127 | assert orig_params == params_3 128 | if isinstance(aug_2, torch.Tensor): 129 | assert torch.all(torch.eq(aug_2, aug_3)) 130 | else: 131 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 132 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 133 | 134 | id_params = tx.get_default_params(img=img, processed=True) 135 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 136 | assert id_aug_rem_params == () 137 | assert not ImageChops.difference(id_aug_img, img).getbbox() 138 | 139 | 140 | @pytest.mark.parametrize( 141 | "size, p, repeat_count, core_txs, channels, tx_mode", 142 | list( 143 | product( 144 | [ 145 | ( 146 | 224, 147 | 224, 148 | ), 149 | ], 150 | # p 151 | [None, "random", "equal"], 152 | # repeat count to try out a large number of orderings 153 | [idx for idx in range(5)], 154 | [ 155 | [ 156 | ptx.RandomHorizontalFlip(p=0.5), 157 | ptx.ColorJitter( 158 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 159 | ), 160 | ptx.RandomGrayscale(p=0.5), 161 | ], 162 | [ 163 | ptx.RandomHorizontalFlip(p=0.5), 164 | ptx.ColorJitter( 165 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2 166 | ), 167 | ptx.RandomGrayscale(p=0.5), 168 | ], 169 | ], 170 | [ 171 | (3,), 172 | ], 173 | ["CASCADE", "CONSUME"], 174 | ) 175 | ), 176 | ) 177 | def 
test_tx_on_torch_tensors( 178 | size: t.Tuple[int], 179 | repeat_count: int, 180 | p: t.Any, 181 | core_txs: t.List[t.Callable], 182 | channels: t.Tuple[int], 183 | tx_mode: str, 184 | ) -> None: 185 | 186 | # Clean probabilities. 187 | if p == "random": 188 | p = np.random.uniform( 189 | low=0, 190 | high=1.0, 191 | size=[ 192 | len(core_txs), 193 | ], 194 | ) 195 | p /= np.sum(p) 196 | elif p == "equal": 197 | p = np.ones( 198 | shape=[ 199 | len(core_txs), 200 | ] 201 | ) 202 | p /= np.sum(p) 203 | elif p is None: 204 | pass 205 | else: 206 | raise NotImplementedError( 207 | "ERROR | Ensure `p` is `None` or `equal` or `random`." 208 | ) 209 | 210 | if channels == (None,): 211 | channels = () 212 | 213 | size = channels + size 214 | 215 | tx = ptx.RandomChoice(transforms=core_txs, p=p, tx_mode=tx_mode) 216 | 217 | img = torch.from_numpy( 218 | np.random.uniform(low=0.0, high=1.0, size=size).astype(np.uint8) 219 | ) 220 | 221 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 222 | 223 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 224 | assert len(params_1) - len(orig_params) == tx.param_count 225 | 226 | orig_params = () 227 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 228 | aug_3, params_3 = tx.consume_transform(img, params_2) 229 | 230 | assert orig_params == params_3 231 | if isinstance(aug_2, torch.Tensor): 232 | assert torch.all(torch.eq(aug_2, aug_3)) 233 | else: 234 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 235 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 236 | 237 | id_params = tx.get_default_params(img=img, processed=True) 238 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 239 | assert id_aug_rem_params == () 240 | assert torch.all(torch.eq(id_aug_img, img)) 241 | 242 | 243 | # Main. 
244 | if __name__ == "__main__": 245 | pass 246 | -------------------------------------------------------------------------------- /tests/test_composing_transforms/test_RandomOrder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | # Dependencies. 19 | import pytest 20 | 21 | import torch 22 | 23 | import parameterized_transforms.transforms as ptx 24 | 25 | import numpy as np 26 | 27 | import PIL.Image as Image 28 | import PIL.ImageChops as ImageChops 29 | 30 | from itertools import product 31 | 32 | import typing as t 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "size, repeat_count, core_txs, channels, tx_mode", 37 | list( 38 | product( 39 | [ 40 | ( 41 | 224, 42 | 224, 43 | ), 44 | ], 45 | # repeat count to try out a large number of orderings 46 | [idx for idx in range(5)], 47 | [ 48 | [ 49 | ptx.RandomHorizontalFlip(p=0.5), 50 | ptx.ColorJitter( 51 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 52 | ), 53 | ptx.RandomGrayscale(p=0.5), 54 | ], 55 | [ 56 | ptx.RandomHorizontalFlip(p=0.5), 57 | ptx.ColorJitter( 58 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2 59 | ), 60 | ptx.RandomGrayscale(p=0.5), 61 | ], 62 | ], 63 | [ 64 | (3,), 65 | # (4,), 66 | (None,), 67 | ], 68 | ["CASCADE", "CONSUME"], 69 | ) 70 | ), 71 | ) 72 | def test_tx_on_PIL_images( 73 | size: t.Tuple[int], 74 | repeat_count: 
int, 75 | core_txs: t.List[t.Callable], 76 | channels: t.Tuple[int], 77 | tx_mode: str, 78 | ) -> None: 79 | 80 | if channels == (None,): 81 | channels = () 82 | 83 | size = size + channels 84 | 85 | tx = ptx.RandomOrder(transforms=core_txs, tx_mode=tx_mode) 86 | 87 | img = Image.fromarray( 88 | np.random.randint(low=0, high=256, size=size).astype(np.uint8) 89 | ) 90 | 91 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 92 | 93 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 94 | assert len(params_1) - len(orig_params) == tx.param_count 95 | 96 | orig_params = () 97 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 98 | aug_3, params_3 = tx.consume_transform(img, params_2) 99 | 100 | assert orig_params == params_3 101 | if isinstance(aug_2, torch.Tensor): 102 | assert torch.all(torch.eq(aug_2, aug_3)) 103 | else: 104 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 105 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 106 | 107 | id_params = tx.get_default_params(img=img, processed=True) 108 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 109 | assert id_aug_rem_params == () 110 | assert not ImageChops.difference(id_aug_img, img).getbbox() 111 | 112 | 113 | @pytest.mark.parametrize( 114 | "size, repeat_count, core_txs, channels, tx_mode", 115 | list( 116 | product( 117 | [ 118 | ( 119 | 224, 120 | 224, 121 | ), 122 | ], 123 | # repeat count to try out a large number of orderings 124 | [idx for idx in range(5)], 125 | [ 126 | [ 127 | ptx.RandomHorizontalFlip(p=0.5), 128 | ptx.ColorJitter( 129 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 130 | ), 131 | ptx.RandomGrayscale(p=0.5), 132 | ], 133 | [ 134 | ptx.RandomHorizontalFlip(p=0.5), 135 | ptx.ColorJitter( 136 | brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2 137 | ), 138 | ptx.RandomGrayscale(p=0.5), 139 | ], 140 | ], 141 | [ 142 | (3,), 143 | ], 144 | ["CASCADE", "CONSUME"], 145 | ) 146 | ), 147 | ) 148 | 
def test_tx_on_torch_tensors( 149 | size: t.Tuple[int], 150 | repeat_count: int, 151 | core_txs: t.List[t.Callable], 152 | channels: t.Tuple[int], 153 | tx_mode: str, 154 | ) -> None: 155 | 156 | if channels == (None,): 157 | channels = () 158 | 159 | size = channels + size 160 | 161 | tx = ptx.RandomOrder(transforms=core_txs, tx_mode=tx_mode) 162 | 163 | img = torch.from_numpy( 164 | np.random.uniform(low=0.0, high=1.0, size=size).astype(np.uint8) 165 | ) 166 | 167 | orig_params = tuple([3.0, 4.0, -1.0, 0.0]) 168 | 169 | aug_1, params_1 = tx.cascade_transform(img, orig_params) 170 | assert len(params_1) - len(orig_params) == tx.param_count 171 | 172 | orig_params = () 173 | aug_2, params_2 = tx.cascade_transform(img, orig_params) 174 | aug_3, params_3 = tx.consume_transform(img, params_2) 175 | 176 | assert orig_params == params_3 177 | if isinstance(aug_2, torch.Tensor): 178 | assert torch.all(torch.eq(aug_2, aug_3)) 179 | else: 180 | assert not ImageChops.difference(aug_2, aug_3).getbbox() 181 | assert all([isinstance(elt, float) or isinstance(elt, int) for elt in params_2]) 182 | 183 | id_params = tx.get_default_params(img=img, processed=True) 184 | id_aug_img, id_aug_rem_params = tx.consume_transform(img=img, params=id_params) 185 | assert id_aug_rem_params == () 186 | assert torch.all(torch.eq(id_aug_img, img)) 187 | 188 | 189 | # Main. 190 | if __name__ == "__main__": 191 | pass 192 | -------------------------------------------------------------------------------- /tests/test_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Apple Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import parameterized_transforms.core as ptc 19 | import parameterized_transforms.utils as ptu 20 | 21 | import numpy as np 22 | 23 | 24 | def test_get_total_params_count() -> None: 25 | class Object: 26 | def __init__(self, param_count: int) -> None: 27 | self.param_count = param_count 28 | 29 | def __str__(self) -> str: 30 | return self.__repr__() 31 | 32 | def __repr__(self) -> str: 33 | return f"{self.__class__.__name__}" f"(param_count={self.param_count})" 34 | 35 | get_total_params_count = ptu.get_total_params_count 36 | 37 | for NUM_OBJ in [10, 100]: 38 | 39 | obj_list = [ 40 | Object(param_count=np.random.randint(0, 10)) for _ in range(NUM_OBJ) 41 | ] 42 | # print(obj_list) 43 | 44 | obj_tuple = tuple( 45 | Object(param_count=np.random.randint(0, 10)) for _ in range(NUM_OBJ) 46 | ) 47 | # print(obj_tuple) 48 | 49 | obj_dict = dict( 50 | [("", Object(param_count=np.random.randint(0, 10))) for _ in range(NUM_OBJ)] 51 | ) 52 | # print(obj_dict) 53 | 54 | assert get_total_params_count(obj_list) == sum( 55 | [obj.param_count for obj in obj_list] 56 | ) 57 | assert get_total_params_count(obj_tuple) == sum( 58 | [obj.param_count for obj in obj_tuple] 59 | ) 60 | assert get_total_params_count(obj_dict) == sum( 61 | [obj.param_count for obj_name, obj in obj_dict.items()] 62 | ) 63 | 64 | 65 | def test_concat_params() -> None: 66 | 67 | concat_params = ptc.Transform.concat_params 68 | 69 | obj_tuple = [ 70 | (), 71 | ( 72 | 1, 73 | 2, 74 | 3, 75 | ), 76 | (4,), 77 | (), 78 | (), 79 | (5, 6), 80 | ( 81 | 7, 82 | 8, 83 | ), 
84 | ] 85 | assert concat_params(*obj_tuple) == (1, 2, 3, 4, 5, 6, 7, 8) 86 | 87 | params_1, params_2, params_3, params_4 = ( 88 | (), 89 | (1, 2), 90 | ( 91 | 3, 92 | 4, 93 | 5, 94 | ), 95 | (), 96 | ) 97 | assert concat_params(params_1, params_2, params_3, params_4) == (1, 2, 3, 4, 5) 98 | 99 | 100 | # Main. 101 | if __name__ == "__main__": 102 | pass 103 | --------------------------------------------------------------------------------