├── .cursor └── rules │ ├── albumentations-rules.mdc │ ├── coding-guidelines.mdc │ └── mosaic-rules.mdc ├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── code-quality.md │ ├── community.md │ ├── documentation.md │ ├── feature-request.md │ ├── localization.md │ ├── performance.md │ ├── question.md │ └── ui.md └── workflows │ ├── ci.yml │ ├── codeflash.yml │ └── upload_to_pypi.yml ├── .gitignore ├── .markdownlint.json ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MAINTAINERS.md ├── MANIFEST.in ├── README.md ├── albumentations ├── __init__.py ├── augmentations │ ├── __init__.py │ ├── blur │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transforms.py │ ├── crops │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transforms.py │ ├── dropout │ │ ├── __init__.py │ │ ├── channel_dropout.py │ │ ├── coarse_dropout.py │ │ ├── functional.py │ │ ├── grid_dropout.py │ │ ├── mask_dropout.py │ │ ├── transforms.py │ │ └── xy_masking.py │ ├── geometric │ │ ├── __init__.py │ │ ├── distortion.py │ │ ├── flip.py │ │ ├── functional.py │ │ ├── pad.py │ │ ├── resize.py │ │ ├── rotate.py │ │ └── transforms.py │ ├── mixing │ │ ├── __init__.py │ │ ├── domain_adaptation.py │ │ ├── domain_adaptation_functional.py │ │ ├── functional.py │ │ └── transforms.py │ ├── other │ │ ├── __init__.py │ │ ├── lambda_transform.py │ │ └── type_transform.py │ ├── pixel │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transforms.py │ ├── spectrogram │ │ ├── __init__.py │ │ └── transform.py │ ├── text │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transforms.py │ ├── transforms3d │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transforms.py │ └── utils.py ├── check_version.py ├── core │ ├── __init__.py │ ├── bbox_utils.py │ ├── composition.py │ ├── hub_mixin.py │ ├── keypoints_utils.py │ ├── label_manager.py │ ├── pydantic.py │ ├── serialization.py │ ├── transforms_interface.py │ ├── type_definitions.py │ ├── utils.py │ └── validation.py └── pytorch │ ├── __init__.py │ └── transforms.py ├── conda.recipe ├── build_upload.sh ├── conda_build_config.yaml └── meta.yaml ├── docs └── contributing │ ├── coding_guidelines.md │ └── environment_setup.md ├── pyproject.toml ├── requirements-dev.txt ├── setup.py ├── tests ├── __init__.py ├── aug_definitions.py ├── conftest.py ├── files │ ├── LiberationSerif-Bold.ttf │ ├── transform_serialization_v2_with_totensor.json │ └── transform_serialization_v2_without_totensor.json ├── functional │ ├── test_affine.py │ ├── test_blur.py │ ├── test_dropout.py │ ├── test_functional.py │ ├── test_geometric.py │ └── test_mixing.py ├── test_augmentations.py ├── test_bbox.py ├── test_blur.py ├── test_check_version.py ├── test_compose_operators.py ├── test_core.py ├── test_core_utils.py ├── test_crop.py ├── test_domain_adaptation.py ├── test_hub_mixin.py ├── test_keypoint.py ├── test_mixing.py ├── test_mosaic.py ├── test_other.py ├── test_pydantic.py ├── test_pytorch.py ├── test_resize_area_downscale.py ├── test_serialization.py ├── test_targets.py ├── test_text.py ├── test_transforms.py ├── transforms3d │ ├── test_functions.py │ ├── test_pytorch.py │ ├── test_targets.py │ └── test_transforms.py └── utils.py └── tools ├── check_albucore_version.py ├── check_defaults.py ├── check_docstrings.py ├── check_example_docstrings.py ├── check_naming_conflicts.py ├── check_no_defaults_in_schemas.py └── make_transforms_docs.py /.cursor/rules/albumentations-rules.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: 3 | globs: 4 | alwaysApply: true 5 | --- 6 | 7 | # Your rule content 8 | 9 | - We use python 3.10+ typing. I.e. not Tuple, but tuple, not List, but list, not Optional, but | None 10 | - get_params_dependent_on_data should look minimal, but should look small and clear as we just call other functions from it 11 | - we do not use fill_value, but fill. Not fill_mask_value, but fill_mask 12 | - We do not have ANY default values in the InitSchema class 13 | - Use pytest.mark.parametrize for parameterized tests 14 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.py linguist-language=python 2 | *.ipynb linguist-documentation 3 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: albumentations-team # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: ['https://albumentations.ai/support/', 'https://www.paypal.com/paypalme/ternaus'] # Albumentations support page and PayPal.me 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: 'bug' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Describe the bug 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | ### To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. Environment (e.g., OS, Python version, Albumentations version, etc.) 19 | 2. Sample code that produces the bug. 20 | 3. Any error messages or incorrect outputs. 21 | 22 | ### Expected behavior 23 | 24 | A clear and concise description of what you expected to happen. 25 | 26 | ### Actual behavior 27 | 28 | Describe what actually happened, including how it differs from your expectations. 29 | 30 | ### Screenshots 31 | 32 | If applicable, add screenshots to help explain your problem. 33 | 34 | ### Additional context 35 | 36 | Add any other context about the problem here, like links to similar issues or possible solutions you've found. 37 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/code-quality.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Code Quality Improvement 3 | about: Suggestions for improving code quality, including refactoring and adherence to coding standards 4 | title: '' 5 | labels: 'tech debt' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Suggested Improvement 11 | 12 | Describe the code quality improvement you suggest. This could include refactoring, adherence to coding standards, reducing complexity, etc. 13 | 14 | ## Potential Benefits 15 | 16 | Explain the benefits of your suggested improvement, such as increased maintainability, reduced technical debt, or improved performance. 17 | 18 | ## Additional Information 19 | 20 | Provide any additional information or context that could help understand your suggestion, including code snippets or links to best practices. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/community.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Community and Contributions 3 | about: Questions or discussions about contributing to the project 4 | title: '' 5 | labels: 'community' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Your Inquiry 11 | 12 | Describe your question or the discussion topic related to contributing to the project. 13 | 14 | ### Contribution Proposal 15 | 16 | If you have a specific contribution in mind, please describe it here. Include how you believe it will benefit the project. 17 | 18 | ### Seeking Guidance 19 | 20 | If you are looking for guidance on how to start contributing or on specific contribution practices, please detail your needs here. 21 | 22 | ### Additional Context 23 | 24 | Provide any additional context that might help in fostering a productive discussion about community and contributions. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation Request 3 | about: Suggest improvements or request additional documentation 4 | title: '' 5 | labels: 'documentation' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Documentation request 11 | 12 | Briefly describe the documentation improvement or addition you're suggesting. Clearly state if it's a new document, an addition to existing documentation, or a correction. 13 | 14 | ### Motivation 15 | 16 | Explain why this documentation is necessary. Include any specific challenges or confusion you encountered due to a lack of clear documentation. Mention how this improvement can benefit other users. 17 | 18 | ### Suggested location 19 | 20 | Where do you think this documentation should be located? For example, README.md, the project's official documentation site, docstrings, etc. 21 | 22 | ### Additional context 23 | 24 | Provide any additional information that might help understand your request better, such as links to related issues, discussions, or external resources. If you have seen similar documentation in other projects that you found helpful, feel free to share those examples here. 25 | 26 | ### Would you be willing to contribute to the documentation? 27 | 28 | Let us know if you are interested in helping write this documentation. Your contribution can speed up the process and benefit the entire community. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: 'enhancement' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Feature description 11 | 12 | Briefly describe the feature you'd like to see added to Albumentations. Explain the problem it solves or the value it adds to the library. 13 | 14 | ### Motivation and context 15 | 16 | Explain why this feature is important and how it fits into the overall objectives of the Albumentations library. Include any relevant links or examples that might help clarify the feature request. 17 | 18 | ### Possible implementation 19 | 20 | Describe how you envision this feature being implemented, if you have ideas. Include considerations for backward compatibility, performance, and how it would integrate with existing functionalities. 21 | 22 | ### Alternatives 23 | 24 | Have you considered any alternative solutions or features? If so, describe them and explain why they were not suitable. 25 | 26 | ### Additional context 27 | 28 | Add any other context, screenshots, sketches, or code snippets about the feature request here. This can include use cases, benchmarks, or other information that would help the development team understand your request better. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/localization.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Localization and Internationalization 3 | about: Request or suggest adding support for additional languages 4 | title: '' 5 | labels: 'localization' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Language Requested 11 | 12 | Specify the language(s) you are requesting support for. 13 | 14 | ### Reason for Request 15 | 16 | Explain why this language support is important for the project. Include any relevant demographic or user base information. 17 | 18 | ### Contribution 19 | 20 | If you are able to contribute to the localization effort, please let us know here. 21 | 22 | ### Additional Context 23 | 24 | Any other information that might help in understanding the importance of this request. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/performance.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Performance Issue 3 | about: Report performance issues like slow execution or high memory consumption 4 | title: '' 5 | labels: 'performance' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Describe the performance issue 11 | 12 | A clear and concise description of what the issue is, including specific metrics or observations. 13 | 14 | ## Reproduction Steps 15 | 16 | Please provide a minimal, reproducible example or a sequence of steps that demonstrates the performance issue. 17 | 18 | ## Expected vs. Actual Behavior 19 | 20 | Describe what you expected to happen and how it differs from what you are actually experiencing. 21 | 22 | ## Environment 23 | 24 | - OS: 25 | - Python version: 26 | - Albumentations version (or other relevant software versions): 27 | 28 | ## Additional Context 29 | 30 | Any additional information that could help understand and address the issue, such as profiling outputs or benchmarks. 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question or Clarification 3 | about: Ask a question or seek clarification about the project 4 | title: '' 5 | labels: 'question' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Your Question 11 | 12 | Describe your question or the clarification you are seeking in detail. 13 | 14 | ### What have you tried? 15 | 16 | Briefly explain what you have tried so far to find an answer to your question. This could include documentation you have read or experiments you have conducted. 17 | 18 | ### Additional Context 19 | 20 | Provide any additional context or screenshots that might help in answering your question. 21 | 22 | ### Relevant documentation or external resources 23 | 24 | If you have consulted any specific documentation or external resources, please list them here. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/ui.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: UI/UX Enhancement 3 | about: Suggest improvements to the user interface and user experience 4 | title: '' 5 | labels: 'ui/ux' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Enhancement Description 11 | 12 | Describe the UI/UX enhancement you suggest. Be as detailed as possible. 13 | 14 | ### Motivation and Goals 15 | 16 | Explain why this enhancement is needed and what goals it aims to achieve. Include any specific problems it solves or improvements it brings. 17 | 18 | ### Possible Implementation 19 | 20 | If you have ideas about how to implement this enhancement, share them here. Include design sketches or mockups if available. 21 | 22 | ### Additional Context 23 | 24 | Provide any additional context or examples of similar enhancements in other projects that might help illustrate your suggestion. 25 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | paths-ignore: 7 | - '**.md' 8 | - 'benchmark/**' 9 | 10 | jobs: 11 | test_and_lint: 12 | name: Test and lint 13 | runs-on: ${{ matrix.operating-system }} 14 | strategy: 15 | matrix: 16 | operating-system: [ubuntu-latest, windows-latest, macos-latest] 17 | python-version: [3.9, "3.10", "3.11", "3.12"] 18 | include: 19 | - operating-system: ubuntu-latest 20 | path: ~/.cache/pip 21 | - operating-system: windows-latest 22 | path: ~\AppData\Local\pip\Cache 23 | - operating-system: macos-latest 24 | path: ~/Library/Caches/pip 25 | fail-fast: true 26 | steps: 27 | - name: Checkout 28 | uses: actions/checkout@v4 29 | - name: Set up Python 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | 34 | - name: Cache Python packages 35 | uses: actions/cache@v4 36 | with: 37 | path: ${{ matrix.path }} 38 | key: ${{ runner.os }}-uv-${{ hashFiles('**/requirements-dev.txt') }} 39 | restore-keys: | 40 | ${{ runner.os }}-uv-${{ matrix.python-version }}- 41 | ${{ runner.os }}-uv- 42 | 43 | - name: Install uv 44 | run: python -m pip install --upgrade uv 45 | 46 | - name: Install PyTorch 47 | run: | 48 | if [ "${{ matrix.operating-system }}" = "macos-latest" ]; then 49 | uv pip install --system torch==2.6.0 torchvision==0.21.0 50 | else 51 | uv pip install --system torch==2.6.0+cpu torchvision==0.21.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu 52 | fi 53 | shell: bash 54 | 55 | - name: Install all dependencies 56 | run: | 57 | uv pip install --system wheel 58 | uv pip install --system -r requirements-dev.txt 59 | uv pip install --system . 60 | 61 | - name: Run PyTest 62 | run: pytest 63 | 64 | check_code_formatting_types: 65 | name: Check code formatting with ruff and mypy 66 | runs-on: ubuntu-latest 67 | strategy: 68 | fail-fast: true 69 | matrix: 70 | python-version: ["3.9"] 71 | steps: 72 | - name: Checkout 73 | uses: actions/checkout@v4 74 | - name: Set up Python 75 | uses: actions/setup-python@v5 76 | with: 77 | python-version: ${{ matrix.python-version }} 78 | 79 | - name: Install all requirements 80 | run: | 81 | python -m pip install --upgrade uv 82 | uv pip install --system -r requirements-dev.txt 83 | uv pip install --system . 84 | 85 | - name: Run checks 86 | run: pre-commit run --files $(find albumentations -type f) 87 | - name: check-defaults-in-apply 88 | run: python -m tools.check_defaults 89 | 90 | check_transforms_docs: 91 | name: Check Readme is not outdated 92 | runs-on: ubuntu-latest 93 | strategy: 94 | matrix: 95 | python-version: [3.9] 96 | steps: 97 | - name: Checkout 98 | uses: actions/checkout@v4 99 | - name: Set up Python 100 | uses: actions/setup-python@v5 101 | with: 102 | python-version: ${{ matrix.python-version }} 103 | - name: Install all requirements 104 | run: | 105 | python -m pip install --upgrade uv 106 | uv pip install --system requests 107 | uv pip install --system . 108 | - name: Run checks for documentation 109 | run: python -m tools.make_transforms_docs check README.md 110 | -------------------------------------------------------------------------------- /.github/workflows/codeflash.yml: -------------------------------------------------------------------------------- 1 | name: Codeflash 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | # So that this workflow only runs when code within the target module is modified 7 | - 'albumentations/**' 8 | workflow_dispatch: 9 | 10 | concurrency: 11 | # Any new push to the PR will cancel the previous run, so that only the latest code is optimized 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | 16 | jobs: 17 | optimize: 18 | name: Optimize new Python code in this PR 19 | # Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations 20 | if: ${{ github.actor != 'codeflash-ai[bot]' }} 21 | runs-on: ubuntu-latest 22 | env: 23 | CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} 24 | CODEFLASH_PR_NUMBER: ${{ github.event.number }} 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | with: 29 | fetch-depth: 0 30 | - name: "Set up Python" 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: '3.12.9' 34 | - name: Install uv 35 | uses: astral-sh/setup-uv@v5 36 | 37 | - name: Install Dependencies 38 | run: uv sync 39 | 40 | - name: Install additional Dev dependencies 41 | run: uv pip install -r requirements-dev.txt 42 | 43 | - name: Install codeflash and ruff 44 | run: uv pip install codeflash ruff 45 | 46 | - name: run codeflash 47 | run: uv run python -m codeflash.main 48 | -------------------------------------------------------------------------------- /.github/workflows/upload_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload release to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | upload: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: '3.9' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install build twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: __token__ 23 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 24 | run: | 25 | python -m build 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | ### JetBrains template 108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 110 | 111 | # User-specific stuff: 112 | .idea/**/workspace.xml 113 | .idea/**/tasks.xml 114 | .idea/dictionaries 115 | 116 | # Sensitive or high-churn files: 117 | .idea/**/dataSources/ 118 | .idea/**/dataSources.ids 119 | .idea/**/dataSources.local.xml 120 | .idea/**/sqlDataSources.xml 121 | .idea/**/dynamic.xml 122 | .idea/**/uiDesigner.xml 123 | 124 | # Gradle: 125 | .idea/**/gradle.xml 126 | .idea/**/libraries 127 | 128 | # CMake 129 | cmake-build-debug/ 130 | cmake-build-release/ 131 | 132 | # Mongo Explorer plugin: 133 | .idea/**/mongoSettings.xml 134 | 135 | ## File-based project format: 136 | *.iws 137 | 138 | ## Plugin-specific files: 139 | 140 | # IntelliJ 141 | out/ 142 | 143 | # mpeltonen/sbt-idea plugin 144 | .idea_modules/ 145 | 146 | # JIRA plugin 147 | atlassian-ide-plugin.xml 148 | 149 | # Cursive Clojure plugin 150 | .idea/replstate.xml 151 | 152 | # Crashlytics plugin (for Android Studio and IntelliJ) 153 | com_crashlytics_export_strings.xml 154 | crashlytics.properties 155 | crashlytics-build.properties 156 | fabric.properties 157 | 158 | .idea 159 | 160 | conda_build/ 161 | 162 | .vscode/ 163 | 164 | *.ipynb 165 | 166 | .ruff_cache/ 167 | 168 | data/ 169 | 170 | notebooks/ 171 | -------------------------------------------------------------------------------- /.markdownlint.json: -------------------------------------------------------------------------------- 1 | { 2 | "default": true, 3 | "MD013": false, 4 | "MD033": false, 5 | "MD045": false 6 | } 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_commit_msg: | 3 | [pre-commit.ci] auto fixes from pre-commit.com hooks 4 | 5 | for more information, see https://pre-commit.ci 6 | autofix_prs: true 7 | autoupdate_branch: '' 8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' 9 | autoupdate_schedule: weekly 10 | skip: [ ] 11 | submodules: false 12 | 13 | repos: 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v5.0.0 16 | hooks: 17 | - id: check-added-large-files 18 | - id: check-ast 19 | - id: check-builtin-literals 20 | - id: check-case-conflict 21 | - id: check-docstring-first 22 | - id: check-executables-have-shebangs 23 | - id: check-shebang-scripts-are-executable 24 | - id: check-symlinks 25 | - id: check-toml 26 | - id: check-xml 27 | - id: detect-private-key 28 | - id: forbid-new-submodules 29 | - id: forbid-submodules 30 | - id: mixed-line-ending 31 | - id: destroyed-symlinks 32 | - id: fix-byte-order-marker 33 | - id: check-json 34 | - id: debug-statements 35 | - id: end-of-file-fixer 36 | - id: trailing-whitespace 37 | - id: requirements-txt-fixer 38 | - repo: local 39 | hooks: 40 | - id: check-docstrings 41 | name: Check Docstrings for '---' sequences 42 | entry: python tools/check_docstrings.py 43 | language: python 44 | types: [python] 45 | - id: check-naming-conflicts 46 | name: Check for naming conflicts between modules and functions/classes 47 | entry: python -m tools.check_naming_conflicts 48 | language: python 49 | pass_filenames: false 50 | - id: check-example-docstrings 51 | name: Check for 'Examples' sections in transform classes (must be plural form) 52 | entry: python tools/check_example_docstrings.py 53 | language: python 54 | types: [python] 55 | files: ^albumentations/ 56 | pass_filenames: true 57 | additional_dependencies: ["google-docstring-parser>=0.0.7"] 58 | - id: check-no-defaults-in-schemas 59 | name: Check no defaults in BaseModel schemas 60 | entry: python tools/check_no_defaults_in_schemas.py 61 | language: python 62 | types: [python] 63 | files: ^albumentations/ 64 | pass_filenames: true 65 | - repo: local 66 | hooks: 67 | - id: check-albucore-version 68 | name: Check albucore version 69 | entry: python ./tools/check_albucore_version.py 70 | language: system 71 | files: setup.py 72 | - repo: https://github.com/astral-sh/ruff-pre-commit 73 | rev: v0.11.12 74 | hooks: 75 | - id: ruff 76 | exclude: '__pycache__/' 77 | args: [ --fix ] 78 | - id: ruff-format 79 | - repo: https://github.com/pre-commit/pygrep-hooks 80 | rev: v1.10.0 81 | hooks: 82 | - id: python-check-mock-methods 83 | - id: python-use-type-annotations 84 | - id: python-check-blanket-noqa 85 | - id: python-use-type-annotations 86 | - id: text-unicode-replacement-char 87 | - repo: https://github.com/codespell-project/codespell 88 | rev: v2.4.1 89 | hooks: 90 | - id: codespell 91 | additional_dependencies: ["tomli"] 92 | - repo: https://github.com/igorshubovych/markdownlint-cli 93 | rev: v0.45.0 94 | hooks: 95 | - id: markdownlint 96 | - repo: https://github.com/tox-dev/pyproject-fmt 97 | rev: "v2.6.0" 98 | hooks: 99 | - id: pyproject-fmt 100 | - repo: https://github.com/pre-commit/mirrors-mypy 101 | rev: v1.15.0 102 | hooks: 103 | - id: mypy 104 | files: ^albumentations/ 105 | additional_dependencies: [ types-PyYAML, types-setuptools, pydantic>=2.9] 106 | args: 107 | [ --config-file=pyproject.toml ] 108 | - repo: https://github.com/ternaus/google-docstring-parser 109 | rev: 0.0.8 # Use the latest version 110 | hooks: 111 | - id: check-google-docstrings 112 | files: ^albumentations/ 113 | exclude: ^build/ 114 | additional_dependencies: ["tomli>=2.0.0"] 115 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | formats: 7 | - htmlzip 8 | - pdf 9 | - epub 10 | 11 | python: 12 | version: 3.9 13 | system_packages: true 14 | install: 15 | - requirements: docs/requirements.txt 16 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in the Albumentations project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | * Trolling, insulting/derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others' private information, such as physical or electronic addresses, without explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at [iglovikov@gmail.com](mailto:iglovikov@gmail.com). All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 44 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) 45 | 46 | [homepage]: https://www.contributor-covenant.org 47 | 48 | For answers to common questions about this code of conduct, see 49 | [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq) 50 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Albumentations 2 | 3 | Thank you for your interest in contributing to [Albumentations](https://albumentations.ai/)! This guide will help you get started with contributing to our image augmentation library. 4 | 5 | ## Quick Start 6 | 7 | For small changes (e.g., bug fixes), feel free to submit a PR directly. 8 | 9 | For larger changes: 10 | 11 | 1. Create an [issue](https://github.com/albumentations-team/albumentations/issues) outlining your proposed change 12 | 2. Join our [Discord community](https://discord.gg/e6zHCXTvaN) to discuss your idea 13 | 14 | ## Contribution Guides 15 | 16 | We've organized our contribution guidelines into focused documents: 17 | 18 | - [Environment Setup Guide](docs/contributing/environment_setup.md) - How to set up your development environment 19 | - [Coding Guidelines](docs/contributing/coding_guidelines.md) - Code style, best practices, and technical requirements 20 | 21 | ## Contribution Process 22 | 23 | 1. **Find an Issue**: Look for open issues or propose a new one. For newcomers, look for issues labeled "good first issue" 24 | 2. **Set Up**: Follow our [Environment Setup Guide](docs/contributing/environment_setup.md) 25 | 3. **Create a Branch**: `git checkout -b feature/my-new-feature` 26 | 4. **Make Changes**: Write code following our [Coding Guidelines](docs/contributing/coding_guidelines.md) 27 | 5. **Test**: Add tests and ensure all tests pass 28 | 6. **Submit**: Open a Pull Request with a clear description of your changes 29 | 30 | ## Code Review Process 31 | 32 | 1. Maintainers will review your contribution 33 | 2. Address any feedback or questions 34 | 3. Once approved, your code will be merged 35 | 36 | ## Project Structure 37 | 38 | - `albumentations/` - Main source code 39 | - `tests/` - Test suite 40 | - `docs/` - Documentation 41 | 42 | ## Getting Help 43 | 44 | - Join our [Discord community](https://discord.gg/e6zHCXTvaN) 45 | - Open a GitHub [issue](https://github.com/albumentations-team/albumentations/issues) 46 | - Ask questions in your pull request 47 | 48 | ## License 49 | 50 | By contributing, you agree that your contributions will be licensed under the project's MIT License. 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Vladimir Iglovikov, Alexander Buslaev, Alexander Parinov, 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # Project Maintainers 2 | 3 | ## Current Maintainer 4 | 5 | Vladimir Iglovikov 6 | 7 | ## Emeritus Team Members 8 | 9 | - Alexander Buslaev 10 | - Alex Parinov 11 | - Eugene Khvedchenya 12 | - Mikhail Druzhinin 13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | 4 | prune docs/_build 5 | global-exclude docs/augs_overview/*/images/*.jpg 6 | 7 | global-exclude *.py[co] .DS_Store 8 | 9 | # Exclude test, tools, and benchmark directories 10 | prune tests 11 | prune tools 12 | prune benchmark 13 | prune conda.recipe 14 | prune codecov.yaml 15 | prune pre-commit-config.yaml 16 | prune .github 17 | prune requirements-dev.txt 18 | -------------------------------------------------------------------------------- /albumentations/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import metadata 2 | 3 | try: 4 | _metadata = metadata("albumentations") 5 | __version__ = _metadata["Version"] 6 | __author__ = _metadata["Author"] 7 | __maintainer__ = _metadata["Maintainer"] 8 | except Exception: # noqa: BLE001 9 | __version__ = "unknown" 10 | __author__ = "Vladimir Iglovikov" 11 | __maintainer__ = "Vladimir Iglovikov" 12 | 13 | import os 14 | from contextlib import suppress 15 | 16 | from albumentations.check_version import check_for_updates 17 | 18 | from .augmentations import * 19 | from .core.composition import * 20 | from .core.serialization import * 21 | from .core.transforms_interface import * 22 | 23 | with suppress(ImportError): 24 | from .pytorch import * 25 | 26 | # Perform the version check after all other initializations 27 | if os.getenv("NO_ALBUMENTATIONS_UPDATE", "").lower() not in {"true", "1"}: 28 | check_for_updates() 29 | -------------------------------------------------------------------------------- /albumentations/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .blur.transforms import * 2 | from .crops.transforms import * 3 | from .dropout.channel_dropout import * 4 | from .dropout.coarse_dropout import * 5 | from .dropout.grid_dropout import * 6 | from .dropout.mask_dropout import * 7 | from .dropout.transforms import * 8 | from .dropout.xy_masking import * 9 | from .geometric.distortion import * 10 | from .geometric.flip import * 11 | from .geometric.pad import * 12 | from .geometric.resize import * 13 | from .geometric.rotate import * 14 | from .geometric.transforms import * 15 | from .mixing.domain_adaptation import * 16 | from .mixing.transforms import * 17 | from .other.lambda_transform import * 18 | from .other.type_transform import * 19 | from .pixel.transforms import * 20 | from .spectrogram.transform import * 21 | from .text.transforms import * 22 | from .transforms3d.transforms import * 23 | from .utils import * 24 | -------------------------------------------------------------------------------- /albumentations/augmentations/blur/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/blur/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/crops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/crops/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/dropout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/dropout/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/dropout/channel_dropout.py: -------------------------------------------------------------------------------- 1 | """Implementation of the Channel Dropout transform for multi-channel images. 2 | 3 | This module provides the ChannelDropout transform, which randomly drops (sets to a fill value) 4 | one or more channels in multi-channel images. This augmentation can help models become more 5 | robust to missing or corrupted channel information and encourage learning from all available 6 | channels rather than relying on a subset. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from typing import Annotated, Any 12 | 13 | import numpy as np 14 | from albucore import get_num_channels 15 | from pydantic import AfterValidator 16 | 17 | from albumentations.core.pydantic import check_range_bounds 18 | from albumentations.core.transforms_interface import BaseTransformInitSchema, ImageOnlyTransform 19 | 20 | from .functional import channel_dropout 21 | 22 | __all__ = ["ChannelDropout"] 23 | 24 | MIN_DROPOUT_CHANNEL_LIST_LENGTH = 2 25 | 26 | 27 | class ChannelDropout(ImageOnlyTransform): 28 | """Randomly drop channels in the input image. 29 | 30 | This transform randomly selects a number of channels to drop from the input image 31 | and replaces them with a specified fill value. This can improve model robustness 32 | to missing or corrupted channels. 33 | 34 | The technique is conceptually similar to: 35 | - Dropout layers in neural networks, which randomly set input units to 0 during training. 36 | - CoarseDropout augmentation, which drops out regions in the spatial dimensions of the image. 37 | 38 | However, ChannelDropout operates on the channel dimension, effectively "dropping out" 39 | entire color channels or feature maps. 40 | 41 | Args: 42 | channel_drop_range (tuple[int, int]): Range from which to choose the number 43 | of channels to drop. The actual number will be randomly selected from 44 | the inclusive range [min, max]. Default: (1, 1). 45 | fill (float): Pixel value used to fill the dropped channels. 46 | Default: 0. 47 | p (float): Probability of applying the transform. Must be in the range 48 | [0, 1]. Default: 0.5. 49 | 50 | Raises: 51 | NotImplementedError: If the input image has only one channel. 52 | ValueError: If the upper bound of channel_drop_range is greater than or 53 | equal to the number of channels in the input image. 54 | 55 | Targets: 56 | image, volume 57 | 58 | Image types: 59 | uint8, float32 60 | 61 | Examples: 62 | >>> import numpy as np 63 | >>> import albumentations as A 64 | >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8) 65 | >>> transform = A.ChannelDropout(channel_drop_range=(1, 2), fill=128, p=1.0) 66 | >>> result = transform(image=image) 67 | >>> dropped_image = result['image'] 68 | >>> assert dropped_image.shape == image.shape 69 | >>> assert np.any(dropped_image != image) # Some channels should be different 70 | 71 | Note: 72 | - The number of channels to drop is randomly chosen within the specified range. 73 | - Channels are randomly selected for dropping. 74 | - This transform is not applicable to single-channel (grayscale) images. 75 | - The transform will raise an error if it's not possible to drop the specified 76 | number of channels (e.g., trying to drop 3 channels from an RGB image). 77 | - This augmentation can be particularly useful for training models to be robust 78 | against missing or corrupted channel data in multi-spectral or hyperspectral imagery. 79 | 80 | """ 81 | 82 | class InitSchema(BaseTransformInitSchema): 83 | channel_drop_range: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))] 84 | fill: float 85 | 86 | def __init__( 87 | self, 88 | channel_drop_range: tuple[int, int] = (1, 1), 89 | fill: float = 0, 90 | p: float = 0.5, 91 | ): 92 | super().__init__(p=p) 93 | 94 | self.channel_drop_range = channel_drop_range 95 | self.fill = fill 96 | 97 | def apply(self, img: np.ndarray, channels_to_drop: list[int], **params: Any) -> np.ndarray: 98 | """Apply channel dropout to the image. 99 | 100 | Args: 101 | img (np.ndarray): Image to apply channel dropout to. 102 | channels_to_drop (list[int]): List of channel indices to drop. 103 | **params (Any): Additional parameters. 104 | 105 | Returns: 106 | np.ndarray: Image with dropped channels. 107 | 108 | """ 109 | return channel_dropout(img, channels_to_drop, self.fill) 110 | 111 | def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, list[int]]: 112 | """Get parameters that depend on input data. 113 | 114 | Args: 115 | params (dict[str, Any]): Parameters. 116 | data (dict[str, Any]): Input data. 117 | 118 | Returns: 119 | dict[str, list[int]]: Dictionary with channels to drop. 120 | 121 | """ 122 | image = data["image"] if "image" in data else data["images"][0] 123 | num_channels = get_num_channels(image) 124 | if num_channels == 1: 125 | msg = "Images has one channel. ChannelDropout is not defined." 126 | raise NotImplementedError(msg) 127 | 128 | if self.channel_drop_range[1] >= num_channels: 129 | msg = "Can not drop all channels in ChannelDropout." 130 | raise ValueError(msg) 131 | num_drop_channels = self.py_random.randint(*self.channel_drop_range) 132 | channels_to_drop = self.py_random.sample(range(num_channels), k=num_drop_channels) 133 | 134 | return {"channels_to_drop": channels_to_drop} 135 | -------------------------------------------------------------------------------- /albumentations/augmentations/dropout/grid_dropout.py: -------------------------------------------------------------------------------- 1 | """Implementation of grid-based dropout augmentation. 2 | 3 | This module provides GridDropout, which creates a regular grid over the image and drops out 4 | rectangular regions according to the specified grid pattern. Unlike random dropout methods, 5 | grid dropout enforces a structured pattern of occlusions that can help models learn spatial 6 | relationships and context across the entire image space. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from typing import Annotated, Any, Literal 12 | 13 | from pydantic import AfterValidator, Field 14 | 15 | import albumentations.augmentations.dropout.functional as fdropout 16 | from albumentations.augmentations.dropout.transforms import BaseDropout 17 | from albumentations.core.pydantic import check_range_bounds, nondecreasing 18 | 19 | __all__ = ["GridDropout"] 20 | 21 | 22 | class GridDropout(BaseDropout): 23 | """Apply GridDropout augmentation to images, masks, bounding boxes, and keypoints. 24 | 25 | GridDropout drops out rectangular regions of an image and the corresponding mask in a grid fashion. 26 | This technique can help improve model robustness by forcing the network to rely on a broader context 27 | rather than specific local features. 28 | 29 | Args: 30 | ratio (float): The ratio of the mask holes to the unit size (same for horizontal and vertical directions). 31 | Must be between 0 and 1. Default: 0.5. 32 | unit_size_range (tuple[int, int] | None): Range from which to sample grid size. Default: None. 33 | Must be between 2 and the image's shorter edge. If None, grid size is calculated based on image size. 34 | holes_number_xy (tuple[int, int] | None): The number of grid units in x and y directions. 35 | First value should be between 1 and image width//2, 36 | Second value should be between 1 and image height//2. 37 | Default: None. If provided, overrides unit_size_range. 38 | random_offset (bool): Whether to offset the grid randomly between 0 and (grid unit size - hole size). 39 | If True, entered shift_xy is ignored and set randomly. Default: True. 40 | fill (tuple[float, float] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]): 41 | Value for the dropped pixels. Can be: 42 | - int or float: all channels are filled with this value 43 | - tuple: tuple of values for each channel 44 | - 'random': each pixel is filled with random values 45 | - 'random_uniform': each hole is filled with a single random color 46 | - 'inpaint_telea': uses OpenCV Telea inpainting method 47 | - 'inpaint_ns': uses OpenCV Navier-Stokes inpainting method 48 | Default: 0 49 | fill_mask (tuple[float, float] | float | None): Value for the dropped pixels in mask. 50 | If None, the mask is not modified. Default: None. 51 | shift_xy (tuple[int, int]): Offsets of the grid start in x and y directions from (0,0) coordinate. 52 | Only used when random_offset is False. Default: (0, 0). 53 | p (float): Probability of applying the transform. Default: 0.5. 54 | 55 | Targets: 56 | image, mask, bboxes, keypoints, volume, mask3d 57 | 58 | Image types: 59 | uint8, float32 60 | 61 | Note: 62 | - If both unit_size_range and holes_number_xy are None, the grid size is calculated based on the image size. 63 | - The actual number of dropped regions may differ slightly from holes_number_xy due to rounding. 64 | - Inpainting methods ('inpaint_telea', 'inpaint_ns') work only with grayscale or RGB images. 65 | - For 'random_uniform' fill, each grid cell gets a single random color, unlike 'random' where each pixel 66 | gets its own random value. 67 | 68 | Example: 69 | >>> import numpy as np 70 | >>> import albumentations as A 71 | >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8) 72 | >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8) 73 | >>> # Example with standard fill value 74 | >>> aug_basic = A.GridDropout( 75 | ... ratio=0.3, 76 | ... unit_size_range=(10, 20), 77 | ... random_offset=True, 78 | ... p=1.0 79 | ... ) 80 | >>> # Example with random uniform fill 81 | >>> aug_random = A.GridDropout( 82 | ... ratio=0.3, 83 | ... unit_size_range=(10, 20), 84 | ... fill="random_uniform", 85 | ... p=1.0 86 | ... ) 87 | >>> # Example with inpainting 88 | >>> aug_inpaint = A.GridDropout( 89 | ... ratio=0.3, 90 | ... unit_size_range=(10, 20), 91 | ... fill="inpaint_ns", 92 | ... p=1.0 93 | ... ) 94 | >>> transformed = aug_random(image=image, mask=mask) 95 | >>> transformed_image, transformed_mask = transformed["image"], transformed["mask"] 96 | 97 | Reference: 98 | - Paper: https://arxiv.org/abs/2001.04086 99 | - OpenCV Inpainting methods: https://docs.opencv.org/master/df/d3d/tutorial_py_inpainting.html 100 | 101 | """ 102 | 103 | class InitSchema(BaseDropout.InitSchema): 104 | ratio: float = Field(gt=0, le=1) 105 | 106 | random_offset: bool 107 | 108 | unit_size_range: ( 109 | Annotated[tuple[int, int], AfterValidator(check_range_bounds(2, None)), AfterValidator(nondecreasing)] 110 | | None 111 | ) 112 | shift_xy: Annotated[tuple[int, int], AfterValidator(check_range_bounds(0, None))] 113 | 114 | holes_number_xy: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))] | None 115 | 116 | def __init__( 117 | self, 118 | ratio: float = 0.5, 119 | random_offset: bool = True, 120 | unit_size_range: tuple[int, int] | None = None, 121 | holes_number_xy: tuple[int, int] | None = None, 122 | shift_xy: tuple[int, int] = (0, 0), 123 | fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"] = 0, 124 | fill_mask: tuple[float, ...] | float | None = None, 125 | p: float = 0.5, 126 | ): 127 | super().__init__(fill=fill, fill_mask=fill_mask, p=p) 128 | self.ratio = ratio 129 | self.unit_size_range = unit_size_range 130 | self.holes_number_xy = holes_number_xy 131 | self.random_offset = random_offset 132 | self.shift_xy = shift_xy 133 | 134 | def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]: 135 | """Get parameters dependent on the data. 136 | 137 | Args: 138 | params (dict[str, Any]): Dictionary containing parameters. 139 | data (dict[str, Any]): Dictionary containing data. 140 | 141 | Returns: 142 | dict[str, Any]: Dictionary with parameters for transformation. 143 | 144 | """ 145 | image_shape = params["shape"] 146 | if self.holes_number_xy: 147 | grid = self.holes_number_xy 148 | else: 149 | # Calculate grid based on unit_size_range or default 150 | unit_height, unit_width = fdropout.calculate_grid_dimensions( 151 | image_shape, 152 | self.unit_size_range, 153 | self.holes_number_xy, 154 | self.random_generator, 155 | ) 156 | grid = (image_shape[0] // unit_height, image_shape[1] // unit_width) 157 | 158 | holes = fdropout.generate_grid_holes( 159 | image_shape, 160 | grid, 161 | self.ratio, 162 | self.random_offset, 163 | self.shift_xy, 164 | self.random_generator, 165 | ) 166 | return {"holes": holes, "seed": self.random_generator.integers(0, 2**32 - 1)} 167 | -------------------------------------------------------------------------------- /albumentations/augmentations/dropout/xy_masking.py: -------------------------------------------------------------------------------- 1 | """Implementation of XY masking for time-frequency domain transformations. 2 | 3 | This module provides the XYMasking transform, which applies masking strips along the X and Y axes 4 | of an image. This is particularly useful for audio spectrograms, time-series data visualizations, 5 | and other grid-like data representations where masking in specific directions (time or frequency) 6 | can improve model robustness and generalization. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from typing import Any, Literal, cast 12 | 13 | import numpy as np 14 | from pydantic import model_validator 15 | from typing_extensions import Self 16 | 17 | from albumentations.augmentations.dropout.transforms import BaseDropout 18 | from albumentations.core.pydantic import NonNegativeIntRangeType 19 | from albumentations.core.transforms_interface import BaseTransformInitSchema 20 | 21 | __all__ = ["XYMasking"] 22 | 23 | 24 | class XYMasking(BaseDropout): 25 | """Applies masking strips to an image, either horizontally (X axis) or vertically (Y axis), 26 | simulating occlusions. This transform is useful for training models to recognize images 27 | with varied visibility conditions. It's particularly effective for spectrogram images, 28 | allowing spectral and frequency masking to improve model robustness. 29 | 30 | At least one of `max_x_length` or `max_y_length` must be specified, dictating the mask's 31 | maximum size along each axis. 32 | 33 | Args: 34 | num_masks_x (int | tuple[int, int]): Number or range of horizontal regions to mask. Defaults to 0. 35 | num_masks_y (int | tuple[int, int]): Number or range of vertical regions to mask. Defaults to 0. 36 | mask_x_length (int | tuple[int, int]): Specifies the length of the masks along 37 | the X (horizontal) axis. If an integer is provided, it sets a fixed mask length. 38 | If a tuple of two integers (min, max) is provided, 39 | the mask length is randomly chosen within this range for each mask. 40 | This allows for variable-length masks in the horizontal direction. 41 | mask_y_length (int | tuple[int, int]): Specifies the height of the masks along 42 | the Y (vertical) axis. Similar to `mask_x_length`, an integer sets a fixed mask height, 43 | while a tuple (min, max) allows for variable-height masks, chosen randomly 44 | within the specified range for each mask. This flexibility facilitates creating masks of various 45 | sizes in the vertical direction. 46 | fill (tuple[float, float] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]): 47 | Value for the dropped pixels. Can be: 48 | - int or float: all channels are filled with this value 49 | - tuple: tuple of values for each channel 50 | - 'random': each pixel is filled with random values 51 | - 'random_uniform': each hole is filled with a single random color 52 | - 'inpaint_telea': uses OpenCV Telea inpainting method 53 | - 'inpaint_ns': uses OpenCV Navier-Stokes inpainting method 54 | Default: 0 55 | fill_mask (tuple[float, float] | float | None): Fill value for dropout regions in the mask. 56 | If None, mask regions corresponding to image dropouts are unchanged. Default: None 57 | p (float): Probability of applying the transform. Defaults to 0.5. 58 | 59 | Targets: 60 | image, mask, bboxes, keypoints, volume, mask3d 61 | 62 | Image types: 63 | uint8, float32 64 | 65 | Note: Either `max_x_length` or `max_y_length` or both must be defined. 66 | 67 | """ 68 | 69 | class InitSchema(BaseTransformInitSchema): 70 | num_masks_x: NonNegativeIntRangeType 71 | num_masks_y: NonNegativeIntRangeType 72 | mask_x_length: NonNegativeIntRangeType 73 | mask_y_length: NonNegativeIntRangeType 74 | 75 | fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"] 76 | fill_mask: tuple[float, ...] | float | None 77 | 78 | @model_validator(mode="after") 79 | def _check_mask_length(self) -> Self: 80 | if ( 81 | isinstance(self.mask_x_length, int) 82 | and self.mask_x_length <= 0 83 | and isinstance(self.mask_y_length, int) 84 | and self.mask_y_length <= 0 85 | ): 86 | msg = "At least one of `mask_x_length` or `mask_y_length` Should be a positive number." 87 | raise ValueError(msg) 88 | 89 | return self 90 | 91 | def __init__( 92 | self, 93 | num_masks_x: tuple[int, int] | int = 0, 94 | num_masks_y: tuple[int, int] | int = 0, 95 | mask_x_length: tuple[int, int] | int = 0, 96 | mask_y_length: tuple[int, int] | int = 0, 97 | fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"] = 0, 98 | fill_mask: tuple[float, ...] | float | None = None, 99 | p: float = 0.5, 100 | ): 101 | super().__init__(p=p, fill=fill, fill_mask=fill_mask) 102 | self.num_masks_x = cast("tuple[int, int]", num_masks_x) 103 | self.num_masks_y = cast("tuple[int, int]", num_masks_y) 104 | 105 | self.mask_x_length = cast("tuple[int, int]", mask_x_length) 106 | self.mask_y_length = cast("tuple[int, int]", mask_y_length) 107 | 108 | def _validate_mask_length( 109 | self, 110 | mask_length: tuple[int, int] | None, 111 | dimension_size: int, 112 | dimension_name: str, 113 | ) -> None: 114 | """Validate the mask length against the corresponding image dimension size.""" 115 | if mask_length is not None: 116 | if isinstance(mask_length, (tuple, list)): 117 | if mask_length[0] < 0 or mask_length[1] > dimension_size: 118 | raise ValueError( 119 | f"{dimension_name} range {mask_length} is out of valid range [0, {dimension_size}]", 120 | ) 121 | elif mask_length < 0 or mask_length > dimension_size: 122 | raise ValueError(f"{dimension_name} {mask_length} exceeds image {dimension_name} {dimension_size}") 123 | 124 | def get_params_dependent_on_data( 125 | self, 126 | params: dict[str, Any], 127 | data: dict[str, Any], 128 | ) -> dict[str, np.ndarray]: 129 | """Get parameters dependent on the data. 130 | 131 | Args: 132 | params (dict[str, Any]): Dictionary containing parameters. 133 | data (dict[str, Any]): Dictionary containing data. 134 | 135 | Returns: 136 | dict[str, np.ndarray]: Dictionary with parameters for transformation. 137 | 138 | """ 139 | image_shape = params["shape"][:2] 140 | 141 | height, width = image_shape 142 | 143 | self._validate_mask_length(self.mask_x_length, width, "mask_x_length") 144 | self._validate_mask_length(self.mask_y_length, height, "mask_y_length") 145 | 146 | masks_x = self._generate_masks(self.num_masks_x, image_shape, self.mask_x_length, axis="x") 147 | masks_y = self._generate_masks(self.num_masks_y, image_shape, self.mask_y_length, axis="y") 148 | 149 | holes = np.array(masks_x + masks_y) 150 | 151 | return {"holes": holes, "seed": self.random_generator.integers(0, 2**32 - 1)} 152 | 153 | def _generate_mask_size(self, mask_length: tuple[int, int]) -> int: 154 | return self.py_random.randint(*mask_length) 155 | 156 | def _generate_masks( 157 | self, 158 | num_masks: tuple[int, int], 159 | image_shape: tuple[int, int], 160 | max_length: tuple[int, int] | None, 161 | axis: str, 162 | ) -> list[tuple[int, int, int, int]]: 163 | if max_length is None or max_length == 0 or (isinstance(num_masks, (int, float)) and num_masks == 0): 164 | return [] 165 | 166 | masks = [] 167 | num_masks_integer = ( 168 | num_masks if isinstance(num_masks, int) else self.py_random.randint(num_masks[0], num_masks[1]) 169 | ) 170 | 171 | height, width = image_shape 172 | 173 | for _ in range(num_masks_integer): 174 | length = self._generate_mask_size(max_length) 175 | 176 | if axis == "x": 177 | x_min = self.py_random.randint(0, width - length) 178 | y_min = 0 179 | x_max, y_max = x_min + length, height 180 | else: # axis == 'y' 181 | y_min = self.py_random.randint(0, height - length) 182 | x_min = 0 183 | x_max, y_max = width, y_min + length 184 | 185 | masks.append((x_min, y_min, x_max, y_max)) 186 | return masks 187 | -------------------------------------------------------------------------------- /albumentations/augmentations/geometric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/geometric/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/mixing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/mixing/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/other/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/other/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/other/lambda_transform.py: -------------------------------------------------------------------------------- 1 | """Lambda transform module for creating custom user-defined transformations. 2 | 3 | This module provides a flexible transform class that allows users to define their own 4 | custom transformation functions for different targets (image, mask, keypoints, bboxes). 5 | It's particularly useful for implementing custom logic that isn't available in the 6 | standard transforms. 7 | 8 | The Lambda transform accepts different callable functions for each target type and 9 | applies them when the transform is executed. This allows for maximum flexibility 10 | while maintaining compatibility with the Albumentations pipeline structure. 11 | 12 | Key features: 13 | - Apply different custom functions to different target types 14 | - Compatible with all Albumentations pipeline features 15 | - Support for all image types and formats 16 | - Ability to handle any number of channels 17 | - Warning system for lambda expressions and multiprocessing compatibility 18 | 19 | Note that using actual lambda expressions (rather than named functions) can cause 20 | issues with multiprocessing, as lambdas cannot be properly pickled. 21 | """ 22 | 23 | from __future__ import annotations 24 | 25 | import warnings 26 | from types import LambdaType 27 | from typing import Any, Callable 28 | 29 | import numpy as np 30 | 31 | from albumentations.augmentations.pixel import functional as fpixel 32 | from albumentations.core.transforms_interface import NoOp 33 | from albumentations.core.utils import format_args 34 | 35 | __all__ = ["Lambda"] 36 | 37 | 38 | class Lambda(NoOp): 39 | """A flexible transformation class for using user-defined transformation functions per targets. 40 | Function signature must include **kwargs to accept optional arguments like interpolation method, image size, etc: 41 | 42 | Args: 43 | image (Callable[..., Any] | None): Image transformation function. 44 | mask (Callable[..., Any] | None): Mask transformation function. 45 | keypoints (Callable[..., Any] | None): Keypoints transformation function. 46 | bboxes (Callable[..., Any] | None): BBoxes transformation function. 47 | p (float): probability of applying the transform. Default: 1.0. 48 | 49 | Targets: 50 | image, mask, bboxes, keypoints, volume, mask3d 51 | 52 | Image types: 53 | uint8, float32 54 | 55 | Number of channels: 56 | Any 57 | 58 | """ 59 | 60 | def __init__( 61 | self, 62 | image: Callable[..., Any] | None = None, 63 | mask: Callable[..., Any] | None = None, 64 | keypoints: Callable[..., Any] | None = None, 65 | bboxes: Callable[..., Any] | None = None, 66 | name: str | None = None, 67 | p: float = 1.0, 68 | ): 69 | super().__init__(p=p) 70 | 71 | self.name = name 72 | self.custom_apply_fns = dict.fromkeys(("image", "mask", "keypoints", "bboxes"), fpixel.noop) 73 | for target_name, custom_apply_fn in { 74 | "image": image, 75 | "mask": mask, 76 | "keypoints": keypoints, 77 | "bboxes": bboxes, 78 | }.items(): 79 | if custom_apply_fn is not None: 80 | if isinstance(custom_apply_fn, LambdaType) and custom_apply_fn.__name__ == "": 81 | warnings.warn( 82 | "Using lambda is incompatible with multiprocessing. " 83 | "Consider using regular functions or partial().", 84 | stacklevel=2, 85 | ) 86 | 87 | self.custom_apply_fns[target_name] = custom_apply_fn 88 | 89 | def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: 90 | """Apply the Lambda transform to the input image. 91 | 92 | Args: 93 | img (np.ndarray): The input image to apply the Lambda transform to. 94 | **params (Any): Additional parameters (not used in this transform). 95 | 96 | Returns: 97 | np.ndarray: The image with the applied Lambda transform. 98 | 99 | """ 100 | fn = self.custom_apply_fns["image"] 101 | return fn(img, **params) 102 | 103 | def apply_to_mask(self, mask: np.ndarray, **params: Any) -> np.ndarray: 104 | """Apply the Lambda transform to the input mask. 105 | 106 | Args: 107 | mask (np.ndarray): The input mask to apply the Lambda transform to. 108 | **params (Any): Additional parameters (not used in this transform). 109 | 110 | Returns: 111 | np.ndarray: The mask with the applied Lambda transform. 112 | 113 | """ 114 | fn = self.custom_apply_fns["mask"] 115 | return fn(mask, **params) 116 | 117 | def apply_to_bboxes(self, bboxes: np.ndarray, **params: Any) -> np.ndarray: 118 | """Apply the Lambda transform to the input bounding boxes. 119 | 120 | Args: 121 | bboxes (np.ndarray): The input bounding boxes to apply the Lambda transform to. 122 | **params (Any): Additional parameters (not used in this transform). 123 | 124 | Returns: 125 | np.ndarray: The bounding boxes with the applied Lambda transform. 126 | 127 | """ 128 | fn = self.custom_apply_fns["bboxes"] 129 | return fn(bboxes, **params) 130 | 131 | def apply_to_keypoints(self, keypoints: np.ndarray, **params: Any) -> np.ndarray: 132 | """Apply the Lambda transform to the input keypoints. 133 | 134 | Args: 135 | keypoints (np.ndarray): The input keypoints to apply the Lambda transform to. 136 | **params (Any): Additional parameters (not used in this transform). 137 | 138 | Returns: 139 | np.ndarray: The keypoints with the applied Lambda transform. 140 | 141 | """ 142 | fn = self.custom_apply_fns["keypoints"] 143 | return fn(keypoints, **params) 144 | 145 | @classmethod 146 | def is_serializable(cls) -> bool: 147 | """Check if the Lambda transform is serializable. 148 | 149 | Returns: 150 | bool: True if the transform is serializable, False otherwise. 151 | 152 | """ 153 | return False 154 | 155 | def to_dict_private(self) -> dict[str, Any]: 156 | """Convert the Lambda transform to a dictionary. 157 | 158 | Returns: 159 | dict[str, Any]: The dictionary representation of the transform. 160 | 161 | """ 162 | if self.name is None: 163 | msg = ( 164 | "To make a Lambda transform serializable you should provide the `name` argument, " 165 | "e.g. `Lambda(name='my_transform', image=, ...)`." 166 | ) 167 | raise ValueError(msg) 168 | return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name} 169 | 170 | def __repr__(self) -> str: 171 | """Return the string representation of the Lambda transform. 172 | 173 | Returns: 174 | str: The string representation of the Lambda transform. 175 | 176 | """ 177 | state = {"name": self.name} 178 | state.update(self.custom_apply_fns.items()) # type: ignore[arg-type] 179 | state.update(self.get_base_init_args()) 180 | return f"{self.__class__.__name__}({format_args(state)})" 181 | -------------------------------------------------------------------------------- /albumentations/augmentations/pixel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/pixel/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/spectrogram/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/spectrogram/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/spectrogram/transform.py: -------------------------------------------------------------------------------- 1 | """Transforms for spectrogram augmentation. 2 | 3 | This module provides transforms specifically designed for augmenting spectrograms 4 | in audio processing tasks. Includes time reversal, time masking, and frequency 5 | masking transforms commonly used in audio machine learning applications. 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | from warnings import warn 11 | 12 | from pydantic import Field 13 | 14 | from albumentations.augmentations.dropout.xy_masking import XYMasking 15 | from albumentations.augmentations.geometric.flip import HorizontalFlip 16 | from albumentations.core.transforms_interface import BaseTransformInitSchema 17 | from albumentations.core.type_definitions import ALL_TARGETS 18 | 19 | __all__ = [ 20 | "FrequencyMasking", 21 | "TimeMasking", 22 | "TimeReverse", 23 | ] 24 | 25 | 26 | class TimeReverse(HorizontalFlip): 27 | """Reverse the time axis of a spectrogram image, also known as time inversion. 28 | 29 | Time inversion of a spectrogram is analogous to the random flip of an image, 30 | an augmentation technique widely used in the visual domain. This can be relevant 31 | in the context of audio classification tasks when working with spectrograms. 32 | The technique was successfully applied in the AudioCLIP paper, which extended 33 | CLIP to handle image, text, and audio inputs. 34 | 35 | This transform is implemented as a subclass of HorizontalFlip since reversing 36 | time in a spectrogram is equivalent to flipping the image horizontally. 37 | 38 | Args: 39 | p (float): probability of applying the transform. Default: 0.5. 40 | 41 | Targets: 42 | image, mask, bboxes, keypoints, volume, mask3d 43 | 44 | Image types: 45 | uint8, float32 46 | 47 | Number of channels: 48 | Any 49 | 50 | Note: 51 | This transform is functionally identical to HorizontalFlip but provides 52 | a more semantically meaningful name when working with spectrograms and 53 | other time-series visualizations. 54 | 55 | References: 56 | - AudioCLIP paper: https://arxiv.org/abs/2106.13043 57 | - Audiomentations: https://iver56.github.io/audiomentations/waveform_transforms/reverse/ 58 | 59 | """ 60 | 61 | _targets = ALL_TARGETS 62 | 63 | class InitSchema(BaseTransformInitSchema): 64 | pass 65 | 66 | def __init__( 67 | self, 68 | p: float = 0.5, 69 | ): 70 | warn( 71 | "TimeReverse is an alias for HorizontalFlip transform. " 72 | "Consider using HorizontalFlip directly from albumentations.HorizontalFlip. ", 73 | UserWarning, 74 | stacklevel=2, 75 | ) 76 | super().__init__(p=p) 77 | 78 | 79 | class TimeMasking(XYMasking): 80 | """Apply masking to a spectrogram in the time domain. 81 | 82 | This transform masks random segments along the time axis of a spectrogram, 83 | implementing the time masking technique proposed in the SpecAugment paper. 84 | Time masking helps in training models to be robust against temporal variations 85 | and missing information in audio signals. 86 | 87 | This is a specialized version of XYMasking configured for time masking only. 88 | For more advanced use cases (e.g., multiple masks, frequency masking, or custom 89 | fill values), consider using XYMasking directly. 90 | 91 | Args: 92 | time_mask_param (int): Maximum possible length of the mask in the time domain. 93 | Must be a positive integer. Length of the mask is uniformly sampled from (0, time_mask_param). 94 | p (float): probability of applying the transform. Default: 0.5. 95 | 96 | Targets: 97 | image, mask, bboxes, keypoints, volume, mask3d 98 | 99 | Image types: 100 | uint8, float32 101 | 102 | Number of channels: 103 | Any 104 | 105 | Note: 106 | This transform is implemented as a subset of XYMasking with fixed parameters: 107 | - Single horizontal mask (num_masks_x=1) 108 | - No vertical masks (num_masks_y=0) 109 | - Zero fill value 110 | - Random mask length up to time_mask_param 111 | 112 | For more flexibility, including: 113 | - Multiple masks 114 | - Custom fill values 115 | - Frequency masking 116 | - Combined time-frequency masking 117 | Consider using albumentations.XYMasking directly. 118 | 119 | References: 120 | - SpecAugment paper: https://arxiv.org/abs/1904.08779 121 | - Original implementation: https://pytorch.org/audio/stable/transforms.html#timemask 122 | 123 | """ 124 | 125 | class InitSchema(BaseTransformInitSchema): 126 | time_mask_param: int = Field(gt=0) 127 | 128 | def __init__( 129 | self, 130 | time_mask_param: int = 40, 131 | p: float = 0.5, 132 | ): 133 | warn( 134 | "TimeMasking is a specialized version of XYMasking. " 135 | "For more flexibility (multiple masks, custom fill values, frequency masking), " 136 | "consider using XYMasking directly from albumentations.XYMasking.", 137 | UserWarning, 138 | stacklevel=2, 139 | ) 140 | super().__init__( 141 | num_masks_x=1, 142 | num_masks_y=0, 143 | mask_x_length=(0, time_mask_param), 144 | fill=0, 145 | fill_mask=0, 146 | p=p, 147 | ) 148 | self.time_mask_param = time_mask_param 149 | 150 | 151 | class FrequencyMasking(XYMasking): 152 | """Apply masking to a spectrogram in the frequency domain. 153 | 154 | This transform masks random segments along the frequency axis of a spectrogram, 155 | implementing the frequency masking technique proposed in the SpecAugment paper. 156 | Frequency masking helps in training models to be robust against frequency variations 157 | and missing spectral information in audio signals. 158 | 159 | This is a specialized version of XYMasking configured for frequency masking only. 160 | For more advanced use cases (e.g., multiple masks, time masking, or custom 161 | fill values), consider using XYMasking directly. 162 | 163 | Args: 164 | freq_mask_param (int): Maximum possible length of the mask in the frequency domain. 165 | Must be a positive integer. Length of the mask is uniformly sampled from (0, freq_mask_param). 166 | p (float): probability of applying the transform. Default: 0.5. 167 | 168 | Targets: 169 | image, mask, bboxes, keypoints, volume, mask3d 170 | 171 | Image types: 172 | uint8, float32 173 | 174 | Number of channels: 175 | Any 176 | 177 | Note: 178 | This transform is implemented as a subset of XYMasking with fixed parameters: 179 | - Single vertical mask (num_masks_y=1) 180 | - No horizontal masks (num_masks_x=0) 181 | - Zero fill value 182 | - Random mask length up to freq_mask_param 183 | 184 | For more flexibility, including: 185 | - Multiple masks 186 | - Custom fill values 187 | - Time masking 188 | - Combined time-frequency masking 189 | Consider using albumentations.XYMasking directly. 190 | 191 | References: 192 | - SpecAugment paper: https://arxiv.org/abs/1904.08779 193 | - Original implementation: https://pytorch.org/audio/stable/transforms.html#freqmask 194 | 195 | """ 196 | 197 | class InitSchema(BaseTransformInitSchema): 198 | freq_mask_param: int = Field(gt=0) 199 | 200 | def __init__( 201 | self, 202 | freq_mask_param: int = 30, 203 | p: float = 0.5, 204 | ): 205 | warn( 206 | "FrequencyMasking is a specialized version of XYMasking. " 207 | "For more flexibility (multiple masks, custom fill values, time masking), " 208 | "consider using XYMasking directly from albumentations.XYMasking.", 209 | UserWarning, 210 | stacklevel=2, 211 | ) 212 | super().__init__( 213 | p=p, 214 | fill=0, 215 | fill_mask=0, 216 | mask_y_length=(0, freq_mask_param), 217 | num_masks_x=0, 218 | num_masks_y=1, 219 | ) 220 | self.freq_mask_param = freq_mask_param 221 | -------------------------------------------------------------------------------- /albumentations/augmentations/text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/text/__init__.py -------------------------------------------------------------------------------- /albumentations/augmentations/transforms3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/augmentations/transforms3d/__init__.py -------------------------------------------------------------------------------- /albumentations/check_version.py: -------------------------------------------------------------------------------- 1 | """Module for checking and comparing albumentations package versions. 2 | 3 | This module provides utilities for version checking and comparison, including 4 | the ability to fetch the latest version from PyPI and compare it with the currently 5 | installed version. It helps users stay informed about available updates and 6 | encourages keeping the library up-to-date with the latest features and bug fixes. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | import json 12 | import re 13 | import urllib.request 14 | from urllib.request import OpenerDirector 15 | from warnings import warn 16 | 17 | from albumentations import __version__ as current_version 18 | 19 | __version__: str = current_version # type: ignore[has-type, unused-ignore] 20 | 21 | SUCCESS_HTML_CODE = 200 22 | 23 | opener = None 24 | 25 | 26 | def get_opener() -> OpenerDirector: 27 | """Get or create a URL opener for making HTTP requests. 28 | 29 | This function implements a singleton pattern for the opener to avoid 30 | recreating it on each request. It lazily instantiates a URL opener 31 | with HTTP and HTTPS handlers. 32 | 33 | Returns: 34 | OpenerDirector: URL opener instance for making HTTP requests. 35 | 36 | """ 37 | global opener # noqa: PLW0603 38 | if opener is None: 39 | opener = urllib.request.build_opener(urllib.request.HTTPHandler(), urllib.request.HTTPSHandler()) 40 | return opener 41 | 42 | 43 | def fetch_version_info() -> str: 44 | """Fetch version information from PyPI for albumentations package. 45 | 46 | This function retrieves JSON data from PyPI containing information about 47 | the latest available version of albumentations. It handles network errors 48 | gracefully and returns an empty string if the request fails. 49 | 50 | Returns: 51 | str: JSON string containing version information if successful, 52 | empty string otherwise. 53 | 54 | """ 55 | opener = get_opener() 56 | url = "https://pypi.org/pypi/albumentations/json" 57 | try: 58 | with opener.open(url, timeout=2) as response: 59 | if response.status == SUCCESS_HTML_CODE: 60 | data = response.read() 61 | encoding = response.info().get_content_charset("utf-8") 62 | return data.decode(encoding) 63 | except Exception as e: # noqa: BLE001 64 | warn(f"Error fetching version info {e}", stacklevel=2) 65 | return "" 66 | 67 | 68 | def parse_version(data: str) -> str: 69 | """Parses the version from the given JSON data.""" 70 | if data: 71 | try: 72 | json_data = json.loads(data) 73 | # Use .get() to avoid KeyError if 'version' is not present 74 | return json_data.get("info", {}).get("version", "") 75 | except json.JSONDecodeError: 76 | # This will handle malformed JSON data 77 | return "" 78 | return "" 79 | 80 | 81 | def compare_versions(v1: tuple[int | str, ...], v2: tuple[int | str, ...]) -> bool: 82 | """Compare two version tuples. 83 | Returns True if v1 > v2, False otherwise. 84 | 85 | Special rules: 86 | 1. Release version > pre-release version (e.g., (1, 4) > (1, 4, 'beta')) 87 | 2. Numeric parts are compared numerically 88 | 3. String parts are compared lexicographically 89 | """ 90 | # First compare common parts 91 | for p1, p2 in zip(v1, v2): 92 | if p1 != p2: 93 | # If both are same type, direct comparison works 94 | if isinstance(p1, int) and isinstance(p2, int): 95 | return p1 > p2 96 | if isinstance(p1, str) and isinstance(p2, str): 97 | return p1 > p2 98 | # If types differ, numbers are greater (release > pre-release) 99 | return isinstance(p1, int) 100 | 101 | # If we get here, all common parts are equal 102 | # Longer version is greater only if next element is a number 103 | if len(v1) > len(v2): 104 | return isinstance(v1[len(v2)], int) 105 | if len(v2) > len(v1): 106 | # v2 is longer, so v1 is greater only if v2's next part is a string (pre-release) 107 | return isinstance(v2[len(v1)], str) 108 | 109 | return False # Versions are equal 110 | 111 | 112 | def parse_version_parts(version_str: str) -> tuple[int | str, ...]: 113 | """Convert version string to tuple of (int | str) parts following PEP 440 conventions. 114 | 115 | Examples: 116 | "1.4.24" -> (1, 4, 24) 117 | "1.4beta" -> (1, 4, "beta") 118 | "1.4.beta2" -> (1, 4, "beta", 2) 119 | "1.4.alpha2" -> (1, 4, "alpha", 2) 120 | 121 | """ 122 | parts = [] 123 | # First split by dots 124 | for part in version_str.split("."): 125 | # Then parse each part for numbers and letters 126 | segments = re.findall(r"([0-9]+|[a-zA-Z]+)", part) 127 | for segment in segments: 128 | if segment.isdigit(): 129 | parts.append(int(segment)) 130 | else: 131 | parts.append(segment.lower()) 132 | return tuple(parts) 133 | 134 | 135 | def check_for_updates() -> None: 136 | """Check if a newer version of albumentations is available on PyPI. 137 | 138 | This function compares the current installed version with the latest version 139 | available on PyPI. If a newer version is found, it issues a warning to the user 140 | with upgrade instructions. All exceptions are caught to ensure this check 141 | doesn't affect normal package operation. 142 | 143 | The check can be disabled by setting the environment variable 144 | NO_ALBUMENTATIONS_UPDATE to 1. 145 | """ 146 | try: 147 | data = fetch_version_info() 148 | latest_version = parse_version(data) 149 | if latest_version: 150 | latest_parts = parse_version_parts(latest_version) 151 | current_parts = parse_version_parts(current_version) 152 | if compare_versions(latest_parts, current_parts): 153 | warn( 154 | f"A new version of Albumentations is available: {latest_version!r} (you have {current_version!r}). " 155 | "Upgrade using: pip install -U albumentations. " 156 | "To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.", 157 | UserWarning, 158 | stacklevel=2, 159 | ) 160 | except Exception as e: # General exception catch to ensure silent failure # noqa: BLE001 161 | warn( 162 | f"Failed to check for updates due to error: {e}. " 163 | "To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.", 164 | UserWarning, 165 | stacklevel=2, 166 | ) 167 | -------------------------------------------------------------------------------- /albumentations/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/albumentations/core/__init__.py -------------------------------------------------------------------------------- /albumentations/core/pydantic.py: -------------------------------------------------------------------------------- 1 | """Module containing Pydantic validation utilities for Albumentations. 2 | 3 | This module provides a collection of validators and utility functions used for validating 4 | parameters in the Pydantic models throughout the Albumentations library. It includes 5 | functions for ensuring numeric ranges are valid, handling type conversions, and creating 6 | standardized validation patterns that are reused across the codebase. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from collections.abc import Callable 12 | from typing import Annotated, TypeVar, Union, overload 13 | 14 | from pydantic.functional_validators import AfterValidator 15 | 16 | from albumentations.core.type_definitions import Number 17 | from albumentations.core.utils import to_tuple 18 | 19 | 20 | def nondecreasing(value: tuple[Number, Number]) -> tuple[Number, Number]: 21 | """Ensure a tuple of two numbers is in non-decreasing order. 22 | 23 | Args: 24 | value (tuple[Number, Number]): Tuple of two numeric values to validate. 25 | 26 | Returns: 27 | tuple[Number, Number]: The original tuple if valid. 28 | 29 | Raises: 30 | ValueError: If the first value is greater than the second value. 31 | 32 | """ 33 | if not value[0] <= value[1]: 34 | raise ValueError(f"First value should be less than the second value, got {value} instead") 35 | return value 36 | 37 | 38 | def process_non_negative_range(value: tuple[float, float] | float | None) -> tuple[float, float]: 39 | """Process and validate a non-negative range. 40 | 41 | Args: 42 | value (tuple[float, float] | float | None): Value to process. Can be: 43 | - A tuple of two floats 44 | - A single float (converted to symmetric range) 45 | - None (defaults to 0) 46 | 47 | Returns: 48 | tuple[float, float]: Validated non-negative range. 49 | 50 | Raises: 51 | ValueError: If any values in the range are negative. 52 | 53 | """ 54 | result = to_tuple(value if value is not None else 0, 0) 55 | if not all(x >= 0 for x in result): 56 | msg = "All values in the non negative range should be non negative" 57 | raise ValueError(msg) 58 | return result 59 | 60 | 61 | def float2int(value: tuple[float, float]) -> tuple[int, int]: 62 | """Convert a tuple of floats to a tuple of integers. 63 | 64 | Args: 65 | value (tuple[float, float]): Tuple of two float values. 66 | 67 | Returns: 68 | tuple[int, int]: Tuple of two integer values. 69 | 70 | """ 71 | return int(value[0]), int(value[1]) 72 | 73 | 74 | NonNegativeFloatRangeType = Annotated[ 75 | Union[tuple[float, float], float], 76 | AfterValidator(process_non_negative_range), 77 | AfterValidator(nondecreasing), 78 | ] 79 | 80 | NonNegativeIntRangeType = Annotated[ 81 | Union[tuple[int, int], int], 82 | AfterValidator(process_non_negative_range), 83 | AfterValidator(nondecreasing), 84 | AfterValidator(float2int), 85 | ] 86 | 87 | 88 | @overload 89 | def create_symmetric_range(value: tuple[int, int] | int) -> tuple[int, int]: ... 90 | 91 | 92 | @overload 93 | def create_symmetric_range(value: tuple[float, float] | float) -> tuple[float, float]: ... 94 | 95 | 96 | def create_symmetric_range(value: tuple[float, float] | float) -> tuple[float, float]: 97 | """Create a symmetric range around zero or use provided range. 98 | 99 | Args: 100 | value (tuple[float, float] | float): Input value, either: 101 | - A tuple of two floats (used directly) 102 | - A single float (converted to (-value, value)) 103 | 104 | Returns: 105 | tuple[float, float]: Symmetric range. 106 | 107 | """ 108 | return to_tuple(value) 109 | 110 | 111 | SymmetricRangeType = Annotated[Union[tuple[float, float], float], AfterValidator(create_symmetric_range)] 112 | 113 | 114 | def convert_to_1plus_range(value: tuple[float, float] | float) -> tuple[float, float]: 115 | """Convert value to a range with lower bound of 1. 116 | 117 | Args: 118 | value (tuple[float, float] | float): Input value. 119 | 120 | Returns: 121 | tuple[float, float]: Range with minimum value of at least 1. 122 | 123 | """ 124 | return to_tuple(value, low=1) 125 | 126 | 127 | def convert_to_0plus_range(value: tuple[float, float] | float) -> tuple[float, float]: 128 | """Convert value to a range with lower bound of 0. 129 | 130 | Args: 131 | value (tuple[float, float] | float): Input value. 132 | 133 | Returns: 134 | tuple[float, float]: Range with minimum value of at least 0. 135 | 136 | """ 137 | return to_tuple(value, low=0) 138 | 139 | 140 | def repeat_if_scalar(value: tuple[float, float] | float) -> tuple[float, float]: 141 | """Convert a scalar value to a tuple by repeating it, or return the tuple as is. 142 | 143 | Args: 144 | value (tuple[float, float] | float): Input value, either a scalar or tuple. 145 | 146 | Returns: 147 | tuple[float, float]: If input is scalar, returns (value, value), otherwise returns input unchanged. 148 | 149 | """ 150 | return (value, value) if isinstance(value, (int, float)) else value 151 | 152 | 153 | T = TypeVar("T", int, float) 154 | 155 | 156 | def check_range_bounds( 157 | min_val: Number, 158 | max_val: Number | None = None, 159 | min_inclusive: bool = True, 160 | max_inclusive: bool = True, 161 | ) -> Callable[[tuple[T, ...] | None], tuple[T, ...] | None]: 162 | """Validates that all values in a tuple are within specified bounds. 163 | 164 | Args: 165 | min_val (int | float): 166 | Minimum allowed value. 167 | max_val (int | float | None): 168 | Maximum allowed value. If None, only lower bound is checked. 169 | min_inclusive (bool): 170 | If True, min_val is inclusive (>=). If False, exclusive (>). 171 | max_inclusive (bool): 172 | If True, max_val is inclusive (<=). If False, exclusive (<). 173 | 174 | Returns: 175 | Callable[[tuple[T, ...] | None], tuple[T, ...] | None]: Validator function that 176 | checks if all values in tuple are within bounds. Returns None if input is None. 177 | 178 | Raises: 179 | ValueError: If any value in tuple is outside the allowed range 180 | 181 | Examples: 182 | >>> validator = check_range_bounds(0, 1) # For [0, 1] range 183 | >>> validator((0.1, 0.5)) # Valid 2D 184 | (0.1, 0.5) 185 | >>> validator((0.1, 0.5, 0.7)) # Valid 3D 186 | (0.1, 0.5, 0.7) 187 | >>> validator((1.1, 0.5)) # Raises ValueError - outside range 188 | >>> validator = check_range_bounds(0, 1, max_inclusive=False) # For [0, 1) range 189 | >>> validator((0, 1)) # Raises ValueError - 1 not included 190 | 191 | """ 192 | 193 | def validator(value: tuple[T, ...] | None) -> tuple[T, ...] | None: 194 | if value is None: 195 | return None 196 | 197 | min_op = (lambda x, y: x >= y) if min_inclusive else (lambda x, y: x > y) 198 | max_op = (lambda x, y: x <= y) if max_inclusive else (lambda x, y: x < y) 199 | 200 | if max_val is None: 201 | if not all(min_op(x, min_val) for x in value): 202 | op_symbol = ">=" if min_inclusive else ">" 203 | raise ValueError(f"All values in {value} must be {op_symbol} {min_val}") 204 | else: 205 | min_symbol = ">=" if min_inclusive else ">" 206 | max_symbol = "<=" if max_inclusive else "<" 207 | if not all(min_op(x, min_val) and max_op(x, max_val) for x in value): 208 | raise ValueError(f"All values in {value} must be {min_symbol} {min_val} and {max_symbol} {max_val}") 209 | return value 210 | 211 | return validator 212 | 213 | 214 | ZeroOneRangeType = Annotated[ 215 | Union[tuple[float, float], float], 216 | AfterValidator(convert_to_0plus_range), 217 | AfterValidator(check_range_bounds(0, 1)), 218 | AfterValidator(nondecreasing), 219 | ] 220 | 221 | 222 | OnePlusFloatRangeType = Annotated[ 223 | Union[tuple[float, float], float], 224 | AfterValidator(convert_to_1plus_range), 225 | AfterValidator(check_range_bounds(1, None)), 226 | ] 227 | OnePlusIntRangeType = Annotated[ 228 | Union[tuple[float, float], float], 229 | AfterValidator(convert_to_1plus_range), 230 | AfterValidator(check_range_bounds(1, None)), 231 | AfterValidator(float2int), 232 | ] 233 | 234 | OnePlusIntNonDecreasingRangeType = Annotated[ 235 | tuple[int, int], 236 | AfterValidator(check_range_bounds(1, None)), 237 | AfterValidator(nondecreasing), 238 | AfterValidator(float2int), 239 | ] 240 | -------------------------------------------------------------------------------- /albumentations/core/type_definitions.py: -------------------------------------------------------------------------------- 1 | """Module containing type definitions and constants used throughout Albumentations. 2 | 3 | This module defines common types, constants, and enumerations that are used across the 4 | Albumentations library. It includes type aliases for numeric types, enumerations for 5 | targets supported by transforms, and constants that define standard dimensions or values 6 | used in image and volumetric data processing. These definitions help ensure type safety 7 | and provide a centralized location for commonly used values. 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | from enum import Enum 13 | from typing import TypeVar, Union 14 | 15 | import cv2 16 | import numpy as np 17 | from albucore.utils import MAX_VALUES_BY_DTYPE 18 | from numpy.typing import NDArray 19 | from typing_extensions import NotRequired, TypedDict 20 | 21 | Number = TypeVar("Number", float, int) 22 | 23 | IntNumType = Union[np.integer, NDArray[np.integer]] 24 | FloatNumType = Union[np.floating, NDArray[np.floating]] 25 | 26 | d4_group_elements = ["e", "r90", "r180", "r270", "v", "hvt", "h", "t"] 27 | 28 | 29 | class ReferenceImage(TypedDict): 30 | """Dictionary-like container for reference image data. 31 | 32 | A typed dictionary defining the structure of reference image data used within 33 | Albumentations, including optional components like masks, bounding boxes, 34 | and keypoints. 35 | 36 | Args: 37 | image (np.ndarray): The reference image array. 38 | mask (np.ndarray | None): Optional mask array. 39 | bbox (tuple[float, ...] | np.ndarray | None): Optional bounding box coordinates. 40 | keypoints (tuple[float, ...] | np.ndarray | None): Optional keypoint coordinates. 41 | 42 | """ 43 | 44 | image: np.ndarray 45 | mask: NotRequired[np.ndarray] 46 | bbox: NotRequired[tuple[float, ...] | np.ndarray] 47 | keypoints: NotRequired[tuple[float, ...] | np.ndarray] 48 | 49 | 50 | class Targets(Enum): 51 | """Enumeration of supported target types in Albumentations. 52 | 53 | This enum defines the different types of data that can be augmented 54 | by Albumentations transforms, including both 2D and 3D targets. 55 | 56 | Args: 57 | IMAGE (str): 2D image target. 58 | MASK (str): 2D mask target. 59 | BBOXES (str): Bounding box target. 60 | KEYPOINTS (str): Keypoint coordinates target. 61 | VOLUME (str): 3D volume target. 62 | MASK3D (str): 3D mask target. 63 | 64 | """ 65 | 66 | IMAGE = "Image" 67 | MASK = "Mask" 68 | BBOXES = "BBoxes" 69 | KEYPOINTS = "Keypoints" 70 | VOLUME = "Volume" 71 | MASK3D = "Mask3D" 72 | 73 | 74 | ALL_TARGETS = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS, Targets.VOLUME, Targets.MASK3D) 75 | 76 | 77 | NUM_VOLUME_DIMENSIONS = 4 78 | NUM_MULTI_CHANNEL_DIMENSIONS = 3 79 | MONO_CHANNEL_DIMENSIONS = 2 80 | NUM_RGB_CHANNELS = 3 81 | 82 | PAIR = 2 83 | TWO = 2 84 | THREE = 3 85 | FOUR = 4 86 | SEVEN = 7 87 | EIGHT = 8 88 | THREE_SIXTY = 360 89 | 90 | BIG_INTEGER = MAX_VALUES_BY_DTYPE[np.uint32] 91 | MAX_RAIN_ANGLE = 45 # Maximum angle for rain augmentation in degrees 92 | 93 | LENGTH_RAW_BBOX = 4 94 | 95 | PercentType = Union[ 96 | float, 97 | tuple[float, float], 98 | tuple[float, float, float, float], 99 | tuple[ 100 | Union[float, tuple[float, float], list[float]], 101 | Union[float, tuple[float, float], list[float]], 102 | Union[float, tuple[float, float], list[float]], 103 | Union[float, tuple[float, float], list[float]], 104 | ], 105 | ] 106 | 107 | 108 | PxType = Union[ 109 | int, 110 | tuple[int, int], 111 | tuple[int, int, int, int], 112 | tuple[ 113 | Union[int, tuple[int, int], list[int]], 114 | Union[int, tuple[int, int], list[int]], 115 | Union[int, tuple[int, int], list[int]], 116 | Union[int, tuple[int, int], list[int]], 117 | ], 118 | ] 119 | 120 | 121 | REFLECT_BORDER_MODES = { 122 | cv2.BORDER_REFLECT_101, 123 | cv2.BORDER_REFLECT, 124 | } 125 | 126 | NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS = 5 127 | NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS = 4 128 | -------------------------------------------------------------------------------- /albumentations/core/validation.py: -------------------------------------------------------------------------------- 1 | """Module containing validation mechanisms for transform parameters. 2 | 3 | This module provides a metaclass that enables parameter validation for transforms using 4 | Pydantic models. It intercepts the initialization of transform classes to validate their 5 | parameters against schema definitions, raising appropriate errors for invalid values and 6 | providing type conversion capabilities. This validation layer helps prevent runtime errors 7 | by catching configuration issues at initialization time. 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | from inspect import Parameter, signature 13 | from typing import Any, Callable 14 | from warnings import warn 15 | 16 | from pydantic import BaseModel, ValidationError 17 | 18 | 19 | class ValidatedTransformMeta(type): 20 | """Metaclass that validates transform parameters during instantiation. 21 | 22 | This metaclass enables automatic validation of transform parameters using Pydantic models, 23 | ensuring proper typing and constraints are enforced before object creation. 24 | 25 | Args: 26 | original_init (Callable[..., Any]): Original __init__ method of the class. 27 | args (tuple[Any, ...]): Positional arguments passed to the __init__ method. 28 | kwargs (dict[str, Any]): Keyword arguments passed to the __init__ method. 29 | 30 | """ 31 | 32 | @staticmethod 33 | def _process_init_parameters( 34 | original_init: Callable[..., Any], 35 | args: tuple[Any, ...], 36 | kwargs: dict[str, Any], 37 | ) -> tuple[dict[str, Any], list[str], bool]: 38 | init_params = signature(original_init).parameters 39 | param_names = list(init_params.keys())[1:] # Exclude 'self' 40 | full_kwargs: dict[str, Any] = dict(zip(param_names, args)) | kwargs 41 | 42 | # Get strict value before validation 43 | strict = full_kwargs.pop("strict", False) 44 | 45 | # Add default values if not provided 46 | for parameter_name, parameter in init_params.items(): 47 | if ( 48 | parameter_name != "self" 49 | and parameter_name not in full_kwargs 50 | and parameter.default is not Parameter.empty 51 | ): 52 | full_kwargs[parameter_name] = parameter.default 53 | 54 | return full_kwargs, param_names, strict 55 | 56 | @staticmethod 57 | def _validate_parameters( 58 | schema_cls: type[BaseModel], 59 | full_kwargs: dict[str, Any], 60 | param_names: list[str], 61 | strict: bool, 62 | ) -> dict[str, Any]: 63 | try: 64 | # Include strict parameter for schema validation 65 | schema_kwargs = {k: v for k, v in full_kwargs.items() if k in param_names} 66 | schema_kwargs["strict"] = strict 67 | config = schema_cls(**schema_kwargs) 68 | validated_kwargs = config.model_dump() 69 | validated_kwargs.pop("strict", None) 70 | except ValidationError as e: 71 | raise ValueError(str(e)) from e 72 | except Exception as e: 73 | if strict: 74 | raise ValueError(str(e)) from e 75 | warn(str(e), stacklevel=2) 76 | return {} 77 | else: 78 | return validated_kwargs 79 | 80 | @staticmethod 81 | def _get_default_values(init_params: dict[str, Parameter]) -> dict[str, Any]: 82 | validated_kwargs = {} 83 | for param_name, param in init_params.items(): 84 | if param_name in {"self", "strict"}: 85 | continue 86 | if param.default is not Parameter.empty: 87 | validated_kwargs[param_name] = param.default 88 | return validated_kwargs 89 | 90 | def __new__(cls: type[Any], name: str, bases: tuple[type, ...], dct: dict[str, Any]) -> type[Any]: 91 | """This is a custom metaclass that validates the parameters of the class during instantiation. 92 | It is used to ensure that the parameters of the class are valid and that they are of the correct type. 93 | """ 94 | if "InitSchema" in dct and issubclass(dct["InitSchema"], BaseModel): 95 | original_init: Callable[..., Any] | None = dct.get("__init__") 96 | if original_init is None: 97 | msg = "__init__ not found in class definition" 98 | raise ValueError(msg) 99 | 100 | original_sig = signature(original_init) 101 | 102 | def custom_init(self: Any, *args: Any, **kwargs: Any) -> None: 103 | full_kwargs, param_names, strict = cls._process_init_parameters(original_init, args, kwargs) 104 | 105 | validated_kwargs = cls._validate_parameters( 106 | dct["InitSchema"], 107 | full_kwargs, 108 | param_names, 109 | strict, 110 | ) or cls._get_default_values(signature(original_init).parameters) 111 | 112 | # Store and check invalid args 113 | invalid_args = [name_arg for name_arg in kwargs if name_arg not in param_names and name_arg != "strict"] 114 | original_init(self, **validated_kwargs) 115 | self.invalid_args = invalid_args 116 | 117 | if invalid_args: 118 | message = f"Argument(s) '{', '.join(invalid_args)}' are not valid for transform {name}" 119 | if strict: 120 | raise ValueError(message) 121 | warn(message, stacklevel=2) 122 | 123 | # Preserve the original signature and docstring 124 | custom_init.__signature__ = original_sig # type: ignore[attr-defined] 125 | custom_init.__doc__ = original_init.__doc__ 126 | 127 | dct["__init__"] = custom_init 128 | 129 | return super().__new__(cls, name, bases, dct) 130 | -------------------------------------------------------------------------------- /albumentations/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | -------------------------------------------------------------------------------- /albumentations/pytorch/transforms.py: -------------------------------------------------------------------------------- 1 | """Module containing PyTorch-specific transforms for Albumentations. 2 | 3 | This module provides transforms that convert NumPy arrays to PyTorch tensors in 4 | the appropriate format. It handles both 2D image data and 3D volumetric data, 5 | ensuring that the tensor dimensions are correctly arranged according to PyTorch's 6 | expected format (channels first). These transforms are typically used as the final 7 | step in an augmentation pipeline before feeding data to a PyTorch model. 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | from typing import Any, overload 13 | 14 | import numpy as np 15 | import torch 16 | 17 | from albumentations.core.transforms_interface import BasicTransform 18 | from albumentations.core.type_definitions import ( 19 | MONO_CHANNEL_DIMENSIONS, 20 | NUM_MULTI_CHANNEL_DIMENSIONS, 21 | NUM_VOLUME_DIMENSIONS, 22 | Targets, 23 | ) 24 | 25 | __all__ = ["ToTensor3D", "ToTensorV2"] 26 | 27 | 28 | class ToTensorV2(BasicTransform): 29 | """Converts images/masks to PyTorch Tensors, inheriting from BasicTransform. 30 | For images: 31 | - If input is in `HWC` format, converts to PyTorch `CHW` format 32 | - If input is in `HW` format, converts to PyTorch `1HW` format (adds channel dimension) 33 | 34 | Attributes: 35 | transpose_mask (bool): If True, transposes 3D input mask dimensions from `[height, width, num_channels]` to 36 | `[num_channels, height, width]`. 37 | p (float): Probability of applying the transform. Default: 1.0. 38 | 39 | """ 40 | 41 | _targets = (Targets.IMAGE, Targets.MASK) 42 | 43 | def __init__(self, transpose_mask: bool = False, p: float = 1.0): 44 | super().__init__(p=p) 45 | self.transpose_mask = transpose_mask 46 | 47 | @property 48 | def targets(self) -> dict[str, Any]: 49 | """Define mapping of target name to target function. 50 | 51 | Returns: 52 | dict[str, Any]: Dictionary mapping target names to corresponding transform functions. 53 | 54 | """ 55 | return { 56 | "image": self.apply, 57 | "images": self.apply_to_images, 58 | "mask": self.apply_to_mask, 59 | "masks": self.apply_to_masks, 60 | } 61 | 62 | def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor: 63 | """Convert a 2D image array to a PyTorch tensor. 64 | 65 | Converts image from HWC or HW format to CHW format, handling both 66 | single-channel and multi-channel images. 67 | 68 | Args: 69 | img (np.ndarray): Image as a numpy array of shape (H,W) or (H,W,C) 70 | **params (Any): Additional parameters 71 | 72 | Returns: 73 | torch.Tensor: PyTorch tensor in CHW format 74 | 75 | Raises: 76 | ValueError: If image dimensions are neither HW nor HWC 77 | 78 | """ 79 | if img.ndim not in {MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS}: 80 | msg = "Albumentations only supports images in HW or HWC format" 81 | raise ValueError(msg) 82 | 83 | if img.ndim == MONO_CHANNEL_DIMENSIONS: 84 | img = np.expand_dims(img, 2) 85 | 86 | return torch.from_numpy(img.transpose(2, 0, 1)) 87 | 88 | def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor: 89 | """Convert a mask array to a PyTorch tensor. 90 | 91 | If transpose_mask is True and mask has 3 dimensions (H,W,C), 92 | converts mask to channels-first format (C,H,W). 93 | 94 | Args: 95 | mask (np.ndarray): Mask as a numpy array 96 | **params (Any): Additional parameters 97 | 98 | Returns: 99 | torch.Tensor: PyTorch tensor of mask 100 | 101 | """ 102 | if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: 103 | mask = mask.transpose(2, 0, 1) 104 | return torch.from_numpy(mask) 105 | 106 | @overload 107 | def apply_to_masks(self, masks: list[np.ndarray], **params: Any) -> list[torch.Tensor]: ... 108 | 109 | @overload 110 | def apply_to_masks(self, masks: np.ndarray, **params: Any) -> torch.Tensor: ... 111 | 112 | def apply_to_masks(self, masks: np.ndarray | list[np.ndarray], **params: Any) -> torch.Tensor | list[torch.Tensor]: 113 | """Convert numpy array or list of numpy array masks to torch tensor(s). 114 | 115 | Args: 116 | masks (np.ndarray | list[np.ndarray]): Numpy array of shape (N, H, W) or (N, H, W, C), 117 | or a list of numpy arrays with shape (H, W) or (H, W, C). 118 | **params (Any): Additional parameters. 119 | 120 | Returns: 121 | torch.Tensor | list[torch.Tensor]: If transpose_mask is True and input is (N, H, W, C), 122 | returns tensor of shape (N, C, H, W). If transpose_mask is True and input is (H, W, C), r 123 | eturns a list of tensors with shape (C, H, W). Otherwise, returns tensors with the same shape as input. 124 | 125 | """ 126 | if isinstance(masks, list): 127 | return [self.apply_to_mask(mask, **params) for mask in masks] 128 | 129 | if self.transpose_mask and masks.ndim == NUM_VOLUME_DIMENSIONS: # (N, H, W, C) 130 | masks = np.transpose(masks, (0, 3, 1, 2)) # -> (N, C, H, W) 131 | return torch.from_numpy(masks) 132 | 133 | def apply_to_images(self, images: np.ndarray, **params: Any) -> torch.Tensor: 134 | """Convert batch of images from (N, H, W, C) to (N, C, H, W).""" 135 | if images.ndim != NUM_VOLUME_DIMENSIONS: # N,H,W,C 136 | raise ValueError(f"Expected 4D array (N,H,W,C), got {images.ndim}D array") 137 | return torch.from_numpy(images.transpose(0, 3, 1, 2)) # -> (N,C,H,W) 138 | 139 | 140 | class ToTensor3D(BasicTransform): 141 | """Convert 3D volumes and masks to PyTorch tensors. 142 | 143 | This transform is designed for 3D medical imaging data. It converts numpy arrays 144 | to PyTorch tensors and ensures consistent channel positioning. 145 | 146 | For all inputs (volumes and masks): 147 | - Input: (D, H, W, C) or (D, H, W) - depth, height, width, [channels] 148 | - Output: (C, D, H, W) - channels first format for PyTorch 149 | For single-channel input, adds C=1 dimension 150 | 151 | Note: 152 | This transform always moves channels to first position as this is 153 | the standard PyTorch format. For masks that need to stay in DHWC format, 154 | use a different transform or handle the transposition after this transform. 155 | 156 | Args: 157 | p (float): Probability of applying the transform. Default: 1.0 158 | 159 | """ 160 | 161 | _targets = (Targets.IMAGE, Targets.MASK) 162 | 163 | def __init__(self, p: float = 1.0): 164 | super().__init__(p=p) 165 | 166 | @property 167 | def targets(self) -> dict[str, Any]: 168 | """Define mapping of target name to target function. 169 | 170 | Returns: 171 | dict[str, Any]: Dictionary mapping target names to corresponding transform functions 172 | 173 | """ 174 | return { 175 | "volume": self.apply_to_volume, 176 | "mask3d": self.apply_to_mask3d, 177 | } 178 | 179 | def apply_to_volume(self, volume: np.ndarray, **params: Any) -> torch.Tensor: 180 | """Convert 3D volume to channels-first tensor.""" 181 | if volume.ndim == NUM_VOLUME_DIMENSIONS: # D,H,W,C 182 | return torch.from_numpy(volume.transpose(3, 0, 1, 2)) 183 | if volume.ndim == NUM_VOLUME_DIMENSIONS - 1: # D,H,W 184 | return torch.from_numpy(volume[np.newaxis, ...]) 185 | raise ValueError(f"Expected 3D or 4D array (D,H,W) or (D,H,W,C), got {volume.ndim}D array") 186 | 187 | def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> torch.Tensor: 188 | """Convert 3D mask to channels-first tensor.""" 189 | return self.apply_to_volume(mask3d, **params) 190 | -------------------------------------------------------------------------------- /conda.recipe/build_upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # How to generate ANACONDA_TOKEN: https://docs.anaconda.com/anaconda-cloud/user-guide/tasks/work-with-accounts#creating-access-tokens 4 | if [ -z "$ANACONDA_TOKEN" ]; then 5 | echo "ANACONDA_TOKEN is unset. Please set it in your environment before running this script"; 6 | exit 1 7 | fi 8 | 9 | conda install -y conda-build conda-verify anaconda-client 10 | conda config --set anaconda_upload no 11 | conda build --quiet --no-test --output-folder conda_build conda.recipe 12 | 13 | # Convert to other platforms: OSX, WIN 14 | conda convert --platform win-64 conda_build/linux-64/*.tar.bz2 -o conda_build/ 15 | conda convert --platform osx-64 conda_build/linux-64/*.tar.bz2 -o conda_build/ 16 | 17 | # Upload to Anaconda 18 | # We could use --all but too many platforms to upload 19 | ls conda_build/*/*.tar.bz2 | xargs -I {} anaconda -v -t $ANACONDA_TOKEN upload -u albumentations {} 20 | -------------------------------------------------------------------------------- /conda.recipe/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | - 3.9 3 | - 3.10 4 | - 3.11 5 | - 3.12 6 | - 3.13 7 | -------------------------------------------------------------------------------- /conda.recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data() %} 2 | 3 | package: 4 | name: albumentations 5 | version: {{ data['version'] }} 6 | 7 | source: 8 | path: .. 9 | 10 | build: 11 | number: 0 12 | script: python -m pip install . --no-deps --ignore-installed --no-cache-dir -vvv 13 | 14 | requirements: 15 | host: 16 | - python 17 | - pip 18 | - setuptools 19 | 20 | build: 21 | - python 22 | - numpy>=1.24.4 23 | - scipy 24 | - pydantic>=2.9 25 | - pytorch 26 | - typing_extensions 27 | - opencv-python-headless 28 | 29 | run: 30 | - python 31 | - numpy>=1.24.4 32 | - scipy>=1.10.0 33 | - pyyaml 34 | - pydantic>=2.9.2 35 | - typing-extensions>=4.9.0 36 | - opencv-python-headless>=4.9.0.80 37 | - albucore==0.0.22 38 | 39 | suggests: 40 | - pytorch 41 | - huggingface-hub 42 | - pillow 43 | 44 | test: 45 | imports: 46 | - albumentations 47 | 48 | about: 49 | home: {{ data['url'] }} 50 | license: {{ data['license'] }} 51 | summary: {{ data['description'] }} 52 | -------------------------------------------------------------------------------- /docs/contributing/environment_setup.md: -------------------------------------------------------------------------------- 1 | # Setting Up Your Development Environment 2 | 3 | This guide will help you set up your development environment for contributing to Albumentations. 4 | 5 | ## Prerequisites 6 | 7 | - Python 3.9 or higher 8 | - Git 9 | - A GitHub account 10 | 11 | ## Step-by-Step Setup 12 | 13 | ### 1. Fork and Clone the Repository 14 | 15 | 1. Fork the [Albumentations repository](https://github.com/albumentations-team/albumentations) on GitHub 16 | 2. Clone your fork locally: 17 | 18 | ```bash 19 | git clone https://github.com/YOUR_USERNAME/albumentations.git 20 | cd albumentations 21 | ``` 22 | 23 | ### 2. Create a Virtual Environment 24 | 25 | Choose the appropriate commands for your operating system: 26 | 27 | #### Linux / macOS 28 | 29 | ```bash 30 | python3 -m venv env 31 | source env/bin/activate 32 | ``` 33 | 34 | #### Windows (cmd.exe) 35 | 36 | ```bash 37 | python -m venv env 38 | env\Scripts\activate.bat 39 | ``` 40 | 41 | #### Windows (PowerShell) 42 | 43 | ```bash 44 | python -m venv env 45 | env\Scripts\activate.ps1 46 | ``` 47 | 48 | ### 3. Install Dependencies 49 | 50 | 1. Install the project in editable mode: 51 | 52 | ```bash 53 | pip install -e . 54 | ``` 55 | 56 | 1. Install development dependencies: 57 | 58 | ```bash 59 | pip install -r requirements-dev.txt 60 | ``` 61 | 62 | ### 4. Set Up Pre-commit Hooks 63 | 64 | Pre-commit hooks help maintain code quality by automatically checking your changes before each commit. 65 | 66 | 1. Install pre-commit: 67 | 68 | ```bash 69 | pip install pre-commit 70 | ``` 71 | 72 | 1. Set up the hooks: 73 | 74 | ```bash 75 | pre-commit install 76 | ``` 77 | 78 | 1. (Optional) Run hooks manually on all files: 79 | 80 | ```bash 81 | pre-commit run --files $(find albumentations -type f) 82 | ``` 83 | 84 | ## Verifying Your Setup 85 | 86 | ### Run Tests 87 | 88 | Ensure everything is set up correctly by running the test suite: 89 | 90 | ```bash 91 | pytest 92 | ``` 93 | 94 | ### Common Issues and Solutions 95 | 96 | #### Permission Errors 97 | 98 | - **Linux/macOS**: If you encounter permission errors, try using `sudo` for system-wide installations or consider using `--user` flag with pip 99 | - **Windows**: Run your terminal as administrator if you encounter permission issues 100 | 101 | #### Virtual Environment Not Activating 102 | 103 | - Ensure you're in the correct directory 104 | - Check that Python is properly installed and in your system PATH 105 | - Try creating the virtual environment with the full Python path 106 | 107 | #### Import Errors After Installation 108 | 109 | - Verify that you're using the correct virtual environment 110 | - Confirm that all dependencies were installed successfully 111 | - Try reinstalling the package in editable mode 112 | 113 | ## Next Steps 114 | 115 | After setting up your environment: 116 | 117 | 1. Create a new branch for your work 118 | 2. Make your changes 119 | 3. Run tests and pre-commit hooks 120 | 4. Submit a pull request 121 | 122 | For more detailed information about contributing, please refer to [Coding Guidelines](./coding_guidelines.md) 123 | 124 | ## Getting Help 125 | 126 | If you encounter any issues with the setup: 127 | 128 | 1. Check our [Discord community](https://discord.gg/e6zHCXTvaN) 129 | 2. Open an [issue on GitHub](https://github.com/albumentations-team/albumentations/issues) 130 | 3. Review existing issues for similar problems and solutions 131 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | 4 | requires = [ "setuptools>=45", "wheel" ] 5 | 6 | [project] 7 | name = "albumentations" 8 | 9 | version = "2.0.8" 10 | 11 | description = "Fast, flexible, and advanced augmentation library for deep learning, computer vision, and medical imaging. Albumentations offers a wide range of transformations for both 2D (images, masks, bboxes, keypoints) and 3D (volumes, volumetric masks, keypoints) data, with optimized performance and seamless integration into ML workflows." 12 | readme = "README.md" 13 | keywords = [ 14 | "2D augmentation", 15 | "3D augmentation", 16 | "aerial photography", 17 | "anomaly detection", 18 | 19 | "artificial intelligence", 20 | 21 | "autonomous driving", 22 | "bounding boxes", 23 | # Core Computer Vision Tasks 24 | "classification", 25 | # Technical Domains 26 | "computer vision", 27 | "computer vision library", 28 | "data augmentation", 29 | 30 | "data preprocessing", 31 | "data science", 32 | "deep learning", 33 | "deep learning library", 34 | 35 | "depth estimation", 36 | "face recognition", 37 | # Performance & Features 38 | "fast augmentation", 39 | # Data Types & Processing 40 | "image augmentation", 41 | "image processing", 42 | "image transformation", 43 | # Data Structures 44 | "images", 45 | "instance segmentation", 46 | "keras", 47 | "keypoint detection", 48 | "keypoints", 49 | "machine learning", 50 | "machine learning tools", 51 | "masks", 52 | # Application Domains 53 | "medical imaging", 54 | "microscopy", 55 | "object counting", 56 | "object detection", 57 | "optimized performance", 58 | "panoptic segmentation", 59 | "pose estimation", 60 | # Development 61 | "python library", 62 | # ML Frameworks 63 | "pytorch", 64 | "quality inspection", 65 | 66 | "real-time processing", 67 | 68 | "robotics vision", 69 | "satellite imagery", 70 | "semantic segmentation", 71 | "tensorflow", 72 | "volumes", 73 | "volumetric data", 74 | "volumetric masks", 75 | 76 | ] 77 | license = { file = "LICENSE" } 78 | 79 | maintainers = [ { name = "Vladimir Iglovikov" } ] 80 | 81 | authors = [ { name = "Vladimir Iglovikov" } ] 82 | requires-python = ">=3.9" 83 | 84 | classifiers = [ 85 | # Development Status 86 | "Development Status :: 5 - Production/Stable", 87 | 88 | # Intended Audience 89 | "Intended Audience :: Developers", 90 | "Intended Audience :: Healthcare Industry", # valid for medical applications 91 | "Intended Audience :: Information Technology", 92 | 93 | "Intended Audience :: Science/Research", 94 | # License 95 | "License :: OSI Approved :: MIT License", 96 | 97 | # Operating System 98 | "Operating System :: OS Independent", 99 | 100 | # Python Versions 101 | "Programming Language :: Python", 102 | "Programming Language :: Python :: 3 :: Only", 103 | "Programming Language :: Python :: 3.9", 104 | "Programming Language :: Python :: 3.10", 105 | "Programming Language :: Python :: 3.11", 106 | "Programming Language :: Python :: 3.12", 107 | 108 | "Programming Language :: Python :: 3.13", 109 | # Topics - Scientific 110 | "Topic :: Scientific/Engineering", 111 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 112 | "Topic :: Scientific/Engineering :: Astronomy", 113 | "Topic :: Scientific/Engineering :: Atmospheric Science", 114 | 115 | "Topic :: Scientific/Engineering :: Bio-Informatics", 116 | "Topic :: Scientific/Engineering :: Image Processing", 117 | "Topic :: Scientific/Engineering :: Physics", 118 | "Topic :: Scientific/Engineering :: Visualization", 119 | # Topics - Software Development 120 | "Topic :: Software Development :: Libraries", 121 | "Topic :: Software Development :: Libraries :: Python Modules", 122 | 123 | # Typing 124 | "Typing :: Typed", 125 | ] 126 | 127 | dynamic = [ "dependencies" ] 128 | optional-dependencies.hub = [ "huggingface-hub" ] 129 | optional-dependencies.pytorch = [ "torch" ] 130 | 131 | optional-dependencies.text = [ "pillow" ] 132 | urls.Homepage = "https://albumentations.ai" 133 | 134 | [tool.setuptools] 135 | packages = { find = { include = [ 136 | "albumentations*", 137 | ], exclude = [ 138 | "tests", 139 | "tools", 140 | "benchmark", 141 | "docs", 142 | ".github", 143 | ".cursor", 144 | ] } } 145 | 146 | package-data = { albumentations = [ "*.txt", "*.md" ] } 147 | 148 | [tool.setuptools.exclude-package-data] 149 | "*" = [ "tests*", "tools*", "benchmark*", "conda.recipe*", "docs*", ".github*", ".cursor" ] 150 | 151 | [tool.ruff] 152 | # Exclude a variety of commonly ignored directories. 153 | target-version = "py39" 154 | 155 | line-length = 120 156 | indent-width = 4 157 | 158 | # Assume Python 3.9 159 | exclude = [ 160 | ".bzr", 161 | ".direnv", 162 | ".eggs", 163 | ".git", 164 | ".git-rewrite", 165 | ".hg", 166 | ".ipynb_checkpoints", 167 | ".mypy_cache", 168 | ".nox", 169 | ".pants.d", 170 | ".pyenv", 171 | ".pytest_cache", 172 | ".pytype", 173 | ".ruff_cache", 174 | ".svn", 175 | ".tox", 176 | ".venv", 177 | ".vscode", 178 | "__pypackages__", 179 | "_build", 180 | "buck-out", 181 | "build", 182 | "dist", 183 | "node_modules", 184 | "setup.py", 185 | "site", 186 | "site-packages", 187 | "tests", 188 | "venv", 189 | ] 190 | 191 | format.indent-style = "space" 192 | # Like Black, respect magic trailing commas. 193 | format.quote-style = "double" 194 | # Like Black, indent with spaces, rather than tabs. 195 | format.line-ending = "auto" 196 | format.skip-magic-trailing-comma = false 197 | # Like Black, automatically detect the appropriate line ending. 198 | lint.select = [ "ALL" ] 199 | lint.ignore = [ 200 | "ANN001", 201 | "ANN204", 202 | "ANN401", 203 | "ARG001", 204 | "ARG002", 205 | "B006", 206 | "B008", 207 | "B027", 208 | "D104", 209 | "D105", 210 | "D106", 211 | "D107", 212 | "D205", 213 | "D213", 214 | "D400", 215 | "D401", 216 | "D404", 217 | "D415", 218 | "EM101", 219 | "EM102", 220 | "F403", 221 | "FBT001", 222 | "FBT002", 223 | "FBT003", 224 | "G004", 225 | "PLR0911", 226 | "PLR0913", 227 | "PLR2004", 228 | "S311", 229 | "S608", 230 | "TC001", 231 | "TC002", 232 | "TC003", 233 | "TRY003", 234 | ] 235 | 236 | lint.per-file-ignores."tools/*" = [ 237 | "ANN201", 238 | "D100", 239 | "D101", 240 | "D103", 241 | "INP001", 242 | "SLF001", 243 | "T201", 244 | ] 245 | 246 | [tool.mypy] 247 | plugins = [ "pydantic.mypy" ] 248 | 249 | python_version = "3.9" 250 | ignore_missing_imports = true 251 | follow_imports = "silent" 252 | warn_redundant_casts = true 253 | warn_unused_ignores = true 254 | disallow_any_generics = true 255 | check_untyped_defs = true 256 | no_implicit_reexport = true 257 | disable_error_code = [ "valid-type" ] 258 | 259 | # for strict mypy: (this is the tricky one :-)) 260 | disallow_untyped_defs = true 261 | 262 | [tool.pydocstyle] 263 | # Allow fix for all enabled rules (when `--fix`) is provided. 264 | 265 | lint.explicit-preview-rules = true 266 | lint.fixable = [ "ALL" ] 267 | lint.unfixable = [ ] 268 | # Allow unused variables when underscore-prefixed. 269 | lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 270 | # Like Black, use double quotes for strings. 271 | lint.pydocstyle.convention = "google" 272 | lint.pydocstyle.ignore-magic-methods = true 273 | 274 | [tool.albumentations.maintainers] 275 | emeritus = [ 276 | "Alexander Buslaev", 277 | "Alex Parinov", 278 | "Eugene Khvedchenya", 279 | "Mikhail Druzhinin", 280 | ] 281 | 282 | [tool.pydantic-mypy] 283 | init_forbid_extra = true 284 | init_typed = true 285 | warn_required_dynamic_aliases = true 286 | 287 | [tool.codeflash] 288 | # All paths are relative to this pyproject.toml's directory. 289 | module-root = "albumentations" 290 | tests-root = "tests" 291 | test-framework = "pytest" 292 | ignore-paths = [ ] 293 | formatter-cmds = [ "ruff check --exit-zero --fix $file", "ruff format $file" ] 294 | 295 | [tool.google_docstring_parser] 296 | paths = [ "albumentations", "tools" ] # Directories or files to scan 297 | require_param_types = true # Require parameter types in docstrings 298 | exclude_files = [ "__init__.py" ] # Files to exclude from checks 299 | verbose = false # Enable verbose output 300 | check_references = true 301 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | deepdiff>=8.0.1 2 | eval-type-backport 3 | pre_commit>=3.5.0 4 | pytest>=8.3.3 5 | pytest_cov>=5.0.0 6 | pytest_mock>=3.14.0 7 | pytz 8 | requests>=2.31.0 9 | scikit-image 10 | scikit-learn 11 | tomli>=2.0.1 12 | torch>=2.3.1 13 | torchvision>=0.18.1 14 | types-PyYAML 15 | types-setuptools 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pkg_resources import DistributionNotFound, get_distribution 3 | from setuptools import setup, find_packages 4 | 5 | INSTALL_REQUIRES = [ 6 | "numpy>=1.24.4", 7 | "scipy>=1.10.0", 8 | "PyYAML", 9 | "typing-extensions>=4.9.0; python_version<'3.10'", 10 | "pydantic>=2.9.2", 11 | "albucore==0.0.24", 12 | "eval-type-backport; python_version<'3.10'", 13 | ] 14 | 15 | MIN_OPENCV_VERSION = "4.9.0.80" 16 | 17 | # OpenCV packages in order of preference 18 | OPENCV_PACKAGES = [ 19 | f"opencv-python>={MIN_OPENCV_VERSION}", 20 | f"opencv-contrib-python>={MIN_OPENCV_VERSION}", 21 | f"opencv-contrib-python-headless>={MIN_OPENCV_VERSION}", 22 | f"opencv-python-headless>={MIN_OPENCV_VERSION}", 23 | ] 24 | 25 | def is_installed(package_name: str) -> bool: 26 | try: 27 | get_distribution(package_name) 28 | return True 29 | except DistributionNotFound: 30 | return False 31 | 32 | def choose_opencv_requirement(): 33 | """Check if any OpenCV package is already installed and use that one.""" 34 | # First try to import cv2 to see if any OpenCV is installed 35 | try: 36 | import cv2 37 | 38 | # Try to determine which package provides the installed cv2 39 | for package in OPENCV_PACKAGES: 40 | package_name = re.split(r"[!<>=]", package)[0].strip() 41 | if is_installed(package_name): 42 | return package 43 | 44 | # If we can import cv2 but can't determine the package, 45 | # don't add any OpenCV requirement 46 | return None 47 | 48 | except ImportError: 49 | # No OpenCV installed, use the headless version as default 50 | return f"opencv-python-headless>={MIN_OPENCV_VERSION}" 51 | 52 | # Add OpenCV requirement if needed 53 | if opencv_req := choose_opencv_requirement(): 54 | INSTALL_REQUIRES.append(opencv_req) 55 | 56 | setup( 57 | packages=find_packages(exclude=["tests", "tools", "benchmark"], include=['albumentations*']), 58 | install_requires=INSTALL_REQUIRES, 59 | ) 60 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import sys 3 | 4 | import numpy as np 5 | import pytest 6 | import cv2 7 | 8 | cv2.setRNGSeed(137) 9 | 10 | np.random.seed(137) 11 | 12 | @pytest.fixture 13 | def mask(): 14 | return cv2.randu(np.empty((100, 100), dtype=np.uint8), 0, 2) 15 | 16 | @pytest.fixture 17 | def image(): 18 | return cv2.randu(np.zeros((100, 100, 3), dtype=np.uint8), 19 | low=np.array([0, 0, 0]), 20 | high=np.array([255, 255, 255])) 21 | 22 | 23 | @pytest.fixture 24 | def bboxes(): 25 | return np.array([[15, 12, 75, 30, 1], [55, 25, 90, 90, 2]]) 26 | 27 | @pytest.fixture 28 | def volume(): 29 | return np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) 30 | 31 | @pytest.fixture 32 | def mask3d(): 33 | return np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) 34 | 35 | 36 | @pytest.fixture 37 | def albumentations_bboxes(): 38 | return np.array([[0.15, 0.12, 0.75, 0.30, 1], [0.55, 0.25, 0.90, 0.90, 2]]) 39 | 40 | 41 | @pytest.fixture 42 | def keypoints(): 43 | return np.array([[30, 20, 0, 0.5, 1], [20, 30, 60, 2.5, 2]], dtype=np.float32) 44 | 45 | 46 | @pytest.fixture 47 | def template(): 48 | return cv2.randu(np.zeros((100, 100, 3), dtype=np.uint8), 0, 255) 49 | 50 | 51 | @pytest.fixture 52 | def float_template(): 53 | return cv2.randu(np.zeros((100, 100, 3), dtype=np.float32), 0, 1) 54 | 55 | 56 | @pytest.fixture(scope="package") 57 | def mp_pool(): 58 | # Usage of `fork` as a start method for multiprocessing could lead to deadlocks on macOS. 59 | # Because `fork` was the default start method for macOS until Python 3.8 60 | # we had to manually set the start method to `spawn` to avoid those issues. 61 | if sys.platform == "darwin": 62 | method = "spawn" 63 | else: 64 | method = None 65 | return multiprocessing.get_context(method).Pool(4) 66 | 67 | SQUARE_UINT8_IMAGE = cv2.randu(np.zeros((100, 100, 3), dtype=np.uint8), 0, 255) 68 | RECTANGULAR_UINT8_IMAGE = cv2.randu(np.zeros((101, 99, 3), dtype=np.uint8), 0, 255) 69 | 70 | SQUARE_FLOAT_IMAGE = cv2.randu(np.zeros((100, 100, 3), dtype=np.float32), 0, 1) 71 | RECTANGULAR_FLOAT_IMAGE = cv2.randu(np.zeros((101, 99, 3), dtype=np.float32), 0, 1) 72 | 73 | UINT8_IMAGES = [SQUARE_UINT8_IMAGE, RECTANGULAR_UINT8_IMAGE] 74 | 75 | FLOAT32_IMAGES = [SQUARE_FLOAT_IMAGE, RECTANGULAR_FLOAT_IMAGE] 76 | 77 | IMAGES = UINT8_IMAGES + FLOAT32_IMAGES 78 | 79 | SQUARE_IMAGES = [SQUARE_UINT8_IMAGE, SQUARE_FLOAT_IMAGE] 80 | RECTANGULAR_IMAGES = [RECTANGULAR_UINT8_IMAGE, RECTANGULAR_FLOAT_IMAGE] 81 | 82 | SQUARE_MULTI_UINT8_IMAGE = np.random.randint(low=0, high=256, size=(100, 100, 5), dtype=np.uint8) 83 | SQUARE_MULTI_FLOAT_IMAGE = np.random.uniform(low=0.0, high=1.0, size=(100, 100, 5)).astype(np.float32) 84 | 85 | MULTI_IMAGES = [SQUARE_MULTI_UINT8_IMAGE, SQUARE_MULTI_FLOAT_IMAGE] 86 | -------------------------------------------------------------------------------- /tests/files/LiberationSerif-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albumentations-team/albumentations/46cc4efd69df585dc355ba0b92b0faf6122e73a6/tests/files/LiberationSerif-Bold.ttf -------------------------------------------------------------------------------- /tests/test_blur.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any 3 | import warnings 4 | import numpy as np 5 | import pytest 6 | 7 | from PIL import Image, ImageFilter 8 | import cv2 9 | 10 | import albumentations as A 11 | from albumentations.augmentations.blur import functional as fblur 12 | 13 | from albumentations.core.transforms_interface import BasicTransform 14 | from tests.conftest import UINT8_IMAGES 15 | 16 | 17 | @pytest.mark.parametrize("aug", [A.Blur, A.MedianBlur, A.MotionBlur]) 18 | @pytest.mark.parametrize( 19 | "blur_limit_input, blur_limit_used", 20 | [[(3, 3), (3, 3)], [(13, 13), (13, 13)]], 21 | ) 22 | @pytest.mark.parametrize("image", UINT8_IMAGES) 23 | def test_blur_kernel_generation( 24 | image: np.ndarray, 25 | aug: BasicTransform, 26 | blur_limit_input: tuple[int, int], 27 | blur_limit_used: tuple[int, int], 28 | ) -> None: 29 | aug = aug(blur_limit=blur_limit_input, p=1) 30 | 31 | assert aug.blur_limit == blur_limit_used 32 | aug(image=image)["image"] 33 | 34 | 35 | @pytest.mark.parametrize("val_uint8", [0, 1, 128, 255]) 36 | def test_glass_blur_float_uint8_diff_less_than_two(val_uint8: list[int]) -> None: 37 | x_uint8 = np.zeros((5, 5)).astype(np.uint8) 38 | x_uint8[2, 2] = val_uint8 39 | 40 | x_float32 = np.zeros((5, 5)).astype(np.float32) 41 | x_float32[2, 2] = val_uint8 / 255.0 42 | 43 | glassblur = A.GlassBlur(p=1, max_delta=1) 44 | glassblur.random_generator = np.random.default_rng(0) 45 | 46 | blur_uint8 = glassblur(image=x_uint8)["image"] 47 | glassblur.random_generator = np.random.default_rng(0) 48 | 49 | blur_float32 = glassblur(image=x_float32)["image"] 50 | 51 | # Before comparison, rescale the blur_float32 to [0, 255] 52 | diff = np.abs(blur_uint8 - blur_float32 * 255) 53 | 54 | # The difference between the results of float32 and uint8 will be at most 2. 55 | assert np.all(diff <= 2.0) 56 | 57 | 58 | @pytest.mark.parametrize("val_uint8", [0, 1, 128, 255]) 59 | def test_advanced_blur_float_uint8_diff_less_than_two(val_uint8: list[int]) -> None: 60 | x_uint8 = np.zeros((5, 5)).astype(np.uint8) 61 | x_uint8[2, 2] = val_uint8 62 | 63 | x_float32 = np.zeros((5, 5)).astype(np.float32) 64 | x_float32[2, 2] = val_uint8 / 255.0 65 | 66 | adv_blur = A.AdvancedBlur(blur_limit=(3, 5), p=1) 67 | adv_blur.set_random_seed(0) 68 | 69 | adv_blur_uint8 = adv_blur(image=x_uint8)["image"] 70 | 71 | adv_blur.set_random_seed(0) 72 | adv_blur_float32 = adv_blur(image=x_float32)["image"] 73 | 74 | # Before comparison, rescale the adv_blur_float32 to [0, 255] 75 | diff = np.abs(adv_blur_uint8 - adv_blur_float32 * 255) 76 | 77 | # The difference between the results of float32 and uint8 will be at most 2. 78 | assert np.all(diff <= 2.0) 79 | 80 | 81 | @pytest.mark.parametrize( 82 | "params", 83 | [ 84 | {"sigma_x_limit": (0.0, 1.0), "sigma_y_limit": (0.0, 1.0)}, 85 | {"beta_limit": (0.1, 0.9)}, 86 | {"beta_limit": (1.1, 8.0)}, 87 | ], 88 | ) 89 | def test_advanced_blur_raises_on_incorrect_params( 90 | params: dict[str, list[float]], 91 | ) -> None: 92 | with pytest.raises(ValueError): 93 | A.AdvancedBlur(**params) 94 | 95 | class MockValidationInfo: 96 | def __init__(self, field_name: str): 97 | self.field_name = field_name 98 | 99 | 100 | @pytest.mark.parametrize( 101 | ["value", "min_value", "expected", "warning_messages"], 102 | [ 103 | # Basic valid cases - no warnings 104 | ((3, 5), 3, (3, 5), []), 105 | ((0, 3), 0, (0, 3), []), 106 | (5, 3, (3, 5), []), 107 | 108 | # Adjust values below min_value 109 | ( 110 | (1, 2), 111 | 3, 112 | (3, 3), 113 | ["test_field: Invalid kernel size range (1, 2). Values less than 3 are not allowed. Range automatically adjusted to (3, 3)."] 114 | ), 115 | # Adjust values below min_value (with automatic odd adjustment) 116 | ( 117 | (-1, 4), 118 | 0, 119 | (0, 5), 120 | [ 121 | "test_field: Non-zero kernel sizes must be odd. Range (0, 4) automatically adjusted to (0, 5)", 122 | "test_field: Invalid kernel size range (-1, 4). Values less than 0 are not allowed. Range automatically adjusted to (0, 4)." 123 | ] 124 | ), 125 | 126 | # Adjust non-odd values 127 | ( 128 | (3, 4), 129 | 3, 130 | (3, 5), 131 | ["test_field: Non-zero kernel sizes must be odd. Range (3, 4) automatically adjusted to (3, 5)."] 132 | ), 133 | ( 134 | (4, 8), 135 | 0, 136 | (5, 9), 137 | ["test_field: Non-zero kernel sizes must be odd. Range (4, 8) automatically adjusted to (5, 9)."] 138 | ), 139 | 140 | # Special case: keep zero values 141 | ( 142 | (0, 4), 143 | 0, 144 | (0, 5), 145 | ["test_field: Non-zero kernel sizes must be odd. Range (0, 4) automatically adjusted to (0, 5)."] 146 | ), 147 | 148 | # Fix min > max 149 | ( 150 | (7, 5), 151 | 3, 152 | (5, 5), 153 | ["test_field: Invalid range (7, 5) (min > max). Range automatically adjusted to (5, 5)."] 154 | ), 155 | # Multiple adjustments 156 | ( 157 | (2, 4), 158 | 3, 159 | (3, 5), 160 | [ 161 | "test_field: Invalid kernel size range (2, 4). Values less than 3 are not allowed. Range automatically adjusted to (3, 4).", 162 | "test_field: Non-zero kernel sizes must be odd. Range (3, 4) automatically adjusted to (3, 5).", 163 | ] 164 | ), 165 | ] 166 | ) 167 | def test_process_blur_limit( 168 | value: Any, 169 | min_value: int, 170 | expected: tuple[int, int], 171 | warning_messages: list[str] 172 | ) -> None: 173 | info = MockValidationInfo("test_field") 174 | 175 | with warnings.catch_warnings(record=True) as w: 176 | warnings.simplefilter("always") 177 | result = fblur.process_blur_limit(value, info, min_value) 178 | 179 | assert result == expected 180 | assert len(w) == len(warning_messages) 181 | 182 | 183 | def test_process_blur_limit_sequence_check() -> None: 184 | """Test that non-sequence values are properly converted to tuples.""" 185 | info = MockValidationInfo("test_field") 186 | 187 | # Test with integer input 188 | result = fblur.process_blur_limit(5, info, min_value=0) 189 | assert isinstance(result, tuple) 190 | assert result == (0, 5) 191 | 192 | # Test with float input 193 | result = fblur.process_blur_limit(5.0, info, min_value=0) 194 | assert isinstance(result, tuple) 195 | assert result == (0, 5) 196 | 197 | 198 | def compute_sharpness(image: np.ndarray) -> float: 199 | kernel = np.array([ 200 | [-1, -1, -1], 201 | [-1, 8, -1], 202 | [-1, -1, -1] 203 | ]) 204 | edges = cv2.filter2D(image.astype(np.float32), -1, kernel) 205 | return np.std(edges) 206 | 207 | def test_gaussian_blur_matches_pil(): 208 | # Create a test image with high-frequency details 209 | image = np.zeros((100, 100), dtype=np.uint8) 210 | image[::10, :] = 255 # horizontal lines 211 | image[:, ::10] = 255 # vertical lines 212 | 213 | # Test points 214 | sigmas = np.linspace(0.2, 10, 50) 215 | 216 | # Get blur progression for PIL 217 | pil_sharpness = [] 218 | alb_sharpness = [] 219 | 220 | pil_image = Image.fromarray(image) 221 | 222 | for sigma in sigmas: 223 | # PIL blur 224 | pil_blurred = pil_image.filter(ImageFilter.GaussianBlur(radius=sigma)) 225 | pil_sharpness.append(compute_sharpness(np.array(pil_blurred))) 226 | 227 | # Albumentations blur 228 | alb_blurred = A.GaussianBlur(blur_limit=0, sigma_limit=(sigma, sigma), p=1.0)(image=image)['image'] 229 | alb_sharpness.append(compute_sharpness(alb_blurred)) 230 | 231 | # Convert to numpy arrays for easier comparison 232 | pil_sharpness = np.array(pil_sharpness) 233 | alb_sharpness = np.array(alb_sharpness) 234 | 235 | # Compare curves directly using absolute differences 236 | abs_diff = np.abs(pil_sharpness - alb_sharpness) 237 | mean_diff = np.mean(abs_diff) 238 | max_diff = np.max(abs_diff) 239 | 240 | 241 | # Assert reasonable absolute differences 242 | assert mean_diff < 10, f"Average absolute difference too high: {mean_diff:.2f}" 243 | assert max_diff < 83, f"Maximum absolute difference too high: {max_diff:.2f}" 244 | -------------------------------------------------------------------------------- /tests/test_check_version.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | 7 | from albumentations.check_version import ( 8 | check_for_updates, 9 | fetch_version_info, 10 | get_opener, 11 | parse_version, 12 | parse_version_parts, 13 | compare_versions, 14 | ) 15 | 16 | 17 | def test_get_opener(): 18 | opener = get_opener() 19 | assert opener is not None 20 | assert opener == get_opener() # Should return the same opener on subsequent calls 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "status_code,expected_result", 25 | [ 26 | (200, '{"info": {"version": "1.0.0"}}'), 27 | (404, ""), 28 | ], 29 | ) 30 | def test_fetch_version_info(status_code, expected_result): 31 | mock_response = MagicMock() 32 | mock_response.status = status_code 33 | mock_response.read.return_value = expected_result.encode("utf-8") 34 | mock_response.__enter__.return_value = mock_response 35 | mock_response.__exit__.return_value = None 36 | 37 | mock_info = MagicMock() 38 | mock_info.get_content_charset.return_value = "utf-8" 39 | mock_response.info.return_value = mock_info 40 | 41 | mock_opener = MagicMock() 42 | mock_opener.open.return_value = mock_response 43 | 44 | with patch("albumentations.check_version.get_opener", return_value=mock_opener): 45 | result = fetch_version_info() 46 | assert result == expected_result 47 | 48 | @pytest.mark.parametrize( 49 | "input_data,expected_version", 50 | [ 51 | ('{"info": {"version": "1.0.0"}}', "1.0.0"), 52 | ('{"info": {}}', ""), 53 | ('{"other": "data"}', ""), 54 | ("invalid json", ""), 55 | ("", ""), 56 | ], 57 | ) 58 | def test_parse_version(input_data, expected_version): 59 | assert parse_version(input_data) == expected_version 60 | 61 | 62 | @pytest.mark.parametrize( 63 | "fetch_data,current_version,expected_warning", 64 | [ 65 | ('{"info": {"version": "1.0.1"}}', "1.0.0", True), 66 | ('{"info": {"version": "1.0.0"}}', "1.0.0", False), 67 | ('{"info": {}}', "1.0.0", False), 68 | ("invalid json", "1.0.0", False), 69 | ], 70 | ) 71 | def test_check_for_updates(fetch_data, current_version, expected_warning): 72 | with ( 73 | patch("albumentations.check_version.fetch_version_info", return_value=fetch_data), 74 | patch("albumentations.check_version.current_version", current_version), 75 | patch("albumentations.check_version.warn") as mock_warn, 76 | ): 77 | check_for_updates() 78 | assert mock_warn.called == expected_warning 79 | 80 | 81 | def test_check_for_updates_exception(): 82 | with ( 83 | patch("albumentations.check_version.fetch_version_info", side_effect=Exception("Test error")), 84 | patch("albumentations.check_version.warn") as mock_warn, 85 | ): 86 | check_for_updates() 87 | mock_warn.assert_called_once() 88 | assert "Failed to check for updates" in mock_warn.call_args[0][0] 89 | 90 | @pytest.mark.parametrize("response_data, expected_version", [ 91 | ('{"info": {"version": "1.0.0"}}', "1.0.0"), 92 | ('{"info": {}}', ""), 93 | ('{}', ""), 94 | ('', ""), 95 | ('invalid json', ""), 96 | ]) 97 | def test_parse_version(response_data: str, expected_version: str): 98 | assert parse_version(response_data) == expected_version 99 | 100 | def test_fetch_version_info_success(): 101 | mock_response = MagicMock() 102 | mock_response.status = 200 103 | mock_response.read.return_value = b'{"info": {"version": "1.0.0"}}' 104 | mock_response.info.return_value.get_content_charset.return_value = "utf-8" 105 | # Set up the context manager behavior 106 | mock_response.__enter__.return_value = mock_response 107 | mock_response.__exit__.return_value = None 108 | 109 | mock_opener = MagicMock() 110 | mock_opener.open.return_value = mock_response 111 | 112 | with patch("albumentations.check_version.get_opener", return_value=mock_opener): 113 | data = fetch_version_info() 114 | assert data == '{"info": {"version": "1.0.0"}}' 115 | 116 | def test_fetch_version_info_failure(): 117 | with patch("urllib.request.OpenerDirector.open", side_effect=Exception("Network error")): 118 | data = fetch_version_info() 119 | assert data == "" 120 | 121 | def test_check_for_updates_no_update(): 122 | with patch("albumentations.check_version.fetch_version_info", return_value='{"info": {"version": "1.0.0"}}'): 123 | with patch("albumentations.check_version.__version__", "1.0.0"): 124 | with patch("warnings.warn") as mock_warn: 125 | check_for_updates() 126 | mock_warn.assert_not_called() 127 | 128 | def test_check_for_updates_with_update(): 129 | with patch("albumentations.check_version.fetch_version_info", return_value='{"info": {"version": "2.0.0"}}'): 130 | with patch("albumentations.check_version.current_version", "1.0.0"): 131 | with patch("albumentations.check_version.warn") as mock_warn: # Patch the imported warn 132 | check_for_updates() 133 | mock_warn.assert_called_once() 134 | 135 | 136 | 137 | @pytest.mark.parametrize("version_str, expected", [ 138 | # Standard versions 139 | ("1.4.24", (1, 4, 24)), 140 | ("0.0.1", (0, 0, 1)), 141 | ("10.20.30", (10, 20, 30)), 142 | 143 | # Pre-release versions 144 | ("1.4beta", (1, 4, "beta")), 145 | ("1.4beta2", (1, 4, "beta", 2)), 146 | ("1.4.beta2", (1, 4, "beta", 2)), 147 | ("1.4.alpha2", (1, 4, "alpha", 2)), 148 | ("1.4rc1", (1, 4, "rc", 1)), 149 | ("1.4.rc.1", (1, 4, "rc", 1)), 150 | 151 | # Mixed case handling 152 | ("1.4Beta2", (1, 4, "beta", 2)), 153 | ("1.4ALPHA2", (1, 4, "alpha", 2)), 154 | ]) 155 | def test_parse_version_parts(version_str: str, expected: tuple[int | str, ...]) -> None: 156 | assert parse_version_parts(version_str) == expected 157 | 158 | # Update the test to use the new comparison function 159 | @pytest.mark.parametrize("version1, version2, expected", [ 160 | # Pre-release ordering 161 | ("1.4beta2", "1.4beta1", True), 162 | ("1.4", "1.4beta", True), 163 | ("1.4beta", "1.4alpha", True), 164 | ("1.4alpha2", "1.4alpha1", True), 165 | ("1.4rc", "1.4beta", True), 166 | ("2.0", "2.0rc1", True), 167 | 168 | # Standard version ordering 169 | ("1.5", "1.4", True), 170 | ("1.4.1", "1.4", True), 171 | ("1.4.24", "1.4.23", True), 172 | ]) 173 | def test_version_comparison(version1: str, version2: str, expected: bool) -> None: 174 | """Test that version1 > version2 matches expected result.""" 175 | v1 = parse_version_parts(version1) 176 | v2 = parse_version_parts(version2) 177 | assert compare_versions(v1, v2) == expected 178 | -------------------------------------------------------------------------------- /tests/test_hub_mixin.py: -------------------------------------------------------------------------------- 1 | """Tests for the HubMixin class.""" 2 | 3 | import platform 4 | from pathlib import Path 5 | from unittest.mock import patch 6 | 7 | import pytest 8 | 9 | from albumentations.core.hub_mixin import HubMixin, is_huggingface_hub_available 10 | 11 | # Skip tests if huggingface_hub is not available 12 | pytestmark = pytest.mark.skipif( 13 | not is_huggingface_hub_available, 14 | reason="huggingface_hub is not available" 15 | ) 16 | 17 | 18 | class TestTransform(HubMixin): 19 | """Test class for HubMixin.""" 20 | 21 | def __init__(self): 22 | """Initialize test transform.""" 23 | pass 24 | 25 | 26 | @pytest.mark.parametrize( 27 | ["path_string", "expected_posix"], 28 | [ 29 | ("normal/path/format", "normal/path/format"), 30 | ("windows\\path\\format", "windows/path/format"), 31 | ("mixed/path\\format", "mixed/path/format"), 32 | (Path("windows\\path\\format"), "windows/path/format"), 33 | ], 34 | ) 35 | def test_windows_path_handling(path_string, expected_posix): 36 | """Test that Windows paths are handled correctly in from_pretrained. 37 | 38 | This test verifies that backslashes in Windows paths are properly converted to forward slashes 39 | when passed to huggingface_hub.hf_hub_download in the from_pretrained method. 40 | 41 | Args: 42 | path_string: Input path with various formats 43 | expected_posix: Expected path after conversion to POSIX format 44 | """ 45 | with patch("albumentations.core.hub_mixin.hf_hub_download") as mock_download: 46 | mock_download.return_value = "config.json" 47 | 48 | # Also mock _from_pretrained to avoid file operations 49 | with patch.object(TestTransform, "_from_pretrained") as mock_from_pretrained: 50 | mock_from_pretrained.return_value = "mocked_transform" 51 | 52 | transform = TestTransform() 53 | transform.from_pretrained(path_string) 54 | 55 | # Check that the repo_id argument was properly formatted 56 | called_args, _ = mock_download.call_args 57 | assert called_args == () 58 | called_kwargs = mock_download.call_args.kwargs 59 | assert called_kwargs["repo_id"] == expected_posix 60 | 61 | 62 | @pytest.mark.skipif(platform.system() != "Windows", reason="Test only relevant on Windows") 63 | def test_real_windows_paths(): 64 | """Test with real Windows paths when running on Windows.""" 65 | with patch("albumentations.core.hub_mixin.hf_hub_download") as mock_download: 66 | mock_download.return_value = "mocked_file_path" 67 | 68 | with patch.object(TestTransform, "_from_pretrained") as mock_from_pretrained: 69 | mock_from_pretrained.return_value = "mocked_transform" 70 | 71 | # Use a Windows-style path string 72 | windows_path = "C:\\Users\\test\\models\\my_model" 73 | TestTransform.from_pretrained(windows_path) 74 | 75 | # Verify the repo_id 76 | repo_id = mock_download.call_args[1]["repo_id"] 77 | assert "\\" not in repo_id, f"Backslash found in repo_id: {repo_id}" 78 | assert repo_id == "C:/Users/test/models/my_model" 79 | -------------------------------------------------------------------------------- /tests/test_mixing.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any 3 | 4 | import numpy as np 5 | import pytest 6 | from deepdiff import DeepDiff 7 | 8 | import albumentations as A 9 | from tests.conftest import UINT8_IMAGES, FLOAT32_IMAGES, MULTI_IMAGES 10 | 11 | 12 | def image_generator(): 13 | yield {"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)} 14 | 15 | 16 | def complex_image_generator(): 17 | height = 100 18 | width = 100 19 | yield {"image": (height, width)} 20 | 21 | 22 | def complex_read_fn_image(x): 23 | return {"image": np.random.randint(0, 256, (x["image"][0], x["image"][1], 3), dtype=np.uint8)} 24 | 25 | 26 | # Mock random.randint to produce consistent results 27 | @pytest.fixture(autouse=True) 28 | def mock_random(monkeypatch): 29 | def mock_randint(start, end): 30 | return start # always return the start value for consistency in tests 31 | 32 | monkeypatch.setattr(random, "randint", mock_randint) 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "metadata, img_shape, expected_output", 37 | [ 38 | ( 39 | # Image + bbox without label + mask + mask_id + label_id + no offset 40 | { 41 | "image": np.ones((20, 20, 3), dtype=np.uint8) * 255, 42 | "bbox": [0.3, 0.3, 0.5, 0.5], 43 | "mask": np.ones((20, 20), dtype=np.uint8) * 127, 44 | "mask_id": 1, 45 | "bbox_id": 99, 46 | }, 47 | (100, 100), 48 | { 49 | "overlay_image": np.ones((20, 20, 3), dtype=np.uint8) * 255, 50 | "overlay_mask": np.ones((20, 20), dtype=np.uint8) * 127, 51 | "offset": (30, 30), 52 | "mask_id": 1, 53 | "bbox": [30, 30, 50, 50, 99], 54 | }, 55 | ), 56 | # Image + bbox with label + mask_id + no mask 57 | ( 58 | {"image": np.ones((20, 20, 3), dtype=np.uint8) * 255, "bbox": [0.3, 0.3, 0.5, 0.5, 99], "mask_id": 1}, 59 | (100, 100), 60 | { 61 | "overlay_image": np.ones((20, 20, 3), dtype=np.uint8) * 255, 62 | "overlay_mask": np.ones((20, 20), dtype=np.uint8), 63 | "offset": (30, 30), 64 | "mask_id": 1, 65 | "bbox": [30, 30, 50, 50, 99], 66 | }, 67 | ), 68 | # Test case with triangular mask 69 | ( 70 | { 71 | "image": np.ones((20, 20, 3), dtype=np.uint8) * 255, 72 | "bbox": [0, 0, 0.2, 0.2], 73 | "mask": np.tri(20, 20, dtype=np.uint8) * 127, 74 | "mask_id": 2, 75 | "bbox_id": 100, 76 | }, 77 | (100, 100), 78 | { 79 | "overlay_image": np.ones((20, 20, 3), dtype=np.uint8) * 255, 80 | "overlay_mask": np.tri(20, 20, dtype=np.uint8) * 127, 81 | "offset": (0, 0), 82 | "mask_id": 2, 83 | "bbox": [0, 0, 20, 20, 100], 84 | }, 85 | ), 86 | # Test case with overlay_image having the same size as img_shape 87 | ( 88 | { 89 | "image": np.ones((100, 100, 3), dtype=np.uint8) * 255, 90 | "bbox": [0, 0, 1, 1], 91 | "mask": np.ones((100, 100), dtype=np.uint8) * 127, 92 | "mask_id": 3, 93 | "bbox_id": 101, 94 | }, 95 | (100, 100), 96 | { 97 | "overlay_image": np.ones((100, 100, 3), dtype=np.uint8) * 255, 98 | "overlay_mask": np.ones((100, 100), dtype=np.uint8) * 127, 99 | "offset": (0, 0), 100 | "mask_id": 3, 101 | "bbox": [0, 0, 100, 100, 101], 102 | }, 103 | ), 104 | ], 105 | ) 106 | def test_preprocess_metadata(metadata: dict[str, Any], img_shape: tuple[int, int], expected_output: dict[str, Any]): 107 | result = A.OverlayElements.preprocess_metadata(metadata, img_shape, random.Random(0)) 108 | 109 | assert DeepDiff(result, expected_output, ignore_type_in_groups=[(tuple, list)]) == {} 110 | 111 | 112 | @pytest.mark.parametrize( 113 | "metadata, expected_output", 114 | [ 115 | ( 116 | { 117 | "image": np.ones((10, 10, 3), dtype=np.uint8) * 255, 118 | "bbox": [0.1, 0.2, 0.2, 0.3], 119 | }, 120 | { 121 | "expected_overlay": np.ones((10, 10, 3), dtype=np.uint8) * 255, 122 | "expected_bbox": [10, 20, 20, 30], 123 | }, 124 | ), 125 | ( 126 | { 127 | "image": np.ones((10, 10, 3), dtype=np.uint8) * 255, 128 | "bbox": [0.3, 0.4, 0.4, 0.5], 129 | "label_id": 99, 130 | }, 131 | { 132 | "expected_overlay": np.ones((10, 10, 3), dtype=np.uint8) * 255, 133 | "expected_bbox": [30, 40, 40, 50, 99], 134 | }, 135 | ), 136 | ( 137 | { 138 | "image": np.ones((10, 10, 3), dtype=np.uint8) * 255, 139 | }, 140 | { 141 | "expected_overlay": np.ones((10, 10, 3), dtype=np.uint8) * 255, 142 | "expected_bbox": [0, 0, 10, 10], 143 | }, 144 | ), 145 | ], 146 | ) 147 | def test_end_to_end(metadata, expected_output): 148 | transform = A.Compose([A.OverlayElements(p=1)], strict=True) 149 | 150 | img = np.zeros((100, 100, 3), dtype=np.uint8) 151 | 152 | transformed = transform(image=img, overlay_metadata=metadata) 153 | 154 | expected_img = np.zeros((100, 100, 3), dtype=np.uint8) 155 | 156 | x_min, y_min, x_max, y_max = expected_output["expected_bbox"][:4] 157 | 158 | expected_img[y_min:y_max, x_min:x_max] = expected_output["expected_overlay"] 159 | 160 | if "bbox" in metadata: 161 | np.testing.assert_array_equal(transformed["image"], expected_img) 162 | else: 163 | assert expected_img.sum() == transformed["image"].sum() 164 | -------------------------------------------------------------------------------- /tests/test_other.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from albucore.utils import get_max_value 4 | 5 | from albumentations.augmentations import FromFloat, ToFloat 6 | 7 | 8 | @pytest.mark.parametrize("dtype", ["uint8", "uint16", "float32", "float64"]) 9 | @pytest.mark.parametrize( 10 | "param, shape", 11 | [ 12 | ("image", (8, 7, 6)), 13 | ("image", (8, 7)), 14 | ("images", (4, 8, 7, 6)), 15 | ("images", (4, 8, 7)), 16 | ] 17 | ) 18 | def test_to_float(param, shape, dtype): 19 | rng = np.random.default_rng() 20 | data = rng.uniform(0, 10, size=shape).astype(dtype) 21 | 22 | aug = ToFloat() 23 | result = aug(**{param: data})[param] 24 | 25 | assert result.dtype == np.float32 26 | assert result.shape == data.shape 27 | np.testing.assert_allclose(data, result * get_max_value(np.dtype(dtype))) 28 | 29 | 30 | @pytest.mark.parametrize("dtype", ["uint8", "uint16"]) 31 | @pytest.mark.parametrize( 32 | "param, shape", 33 | [ 34 | ("image", (8, 7, 6)), 35 | ("image", (8, 7)), 36 | ("images", (4, 8, 7, 6)), 37 | ("images", (4, 8, 7)), 38 | ] 39 | ) 40 | def test_from_float(param, shape, dtype): 41 | rng = np.random.default_rng() 42 | data = rng.random(size=shape, dtype=np.float32) 43 | 44 | aug = FromFloat(dtype=dtype) 45 | result = aug(**{param: data})[param] 46 | 47 | assert result.dtype == np.dtype(dtype) 48 | assert result.shape == data.shape 49 | # Because FromFloat has to round to the nearest integer, we get an absolute difference up to 0.5 50 | np.testing.assert_allclose(data * get_max_value(np.dtype(dtype)), result, atol=0.5) 51 | -------------------------------------------------------------------------------- /tests/test_resize_area_downscale.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import pytest 4 | 5 | import albumentations as A 6 | 7 | 8 | def get_downscale_image(): 9 | """Create 100x100 test image for downscaling tests.""" 10 | image = np.zeros((100, 100, 3), dtype=np.uint8) 11 | # Add recognizable pattern 12 | image[20:80, 20:80, 0] = 255 # Red square 13 | image[40:60, 40:60, 1] = 255 # Green square in the middle 14 | return image 15 | 16 | 17 | def get_upscale_image(): 18 | """Create 50x50 test image for upscaling tests.""" 19 | image = np.zeros((50, 50, 3), dtype=np.uint8) 20 | # Add recognizable pattern 21 | image[10:40, 10:40, 0] = 255 # Red square 22 | image[20:30, 20:30, 1] = 255 # Green square in the middle 23 | return image 24 | 25 | 26 | def get_mask(size): 27 | """Create a mask of the specified size.""" 28 | mask = np.zeros((size, size), dtype=np.uint8) 29 | # Add a simple pattern 30 | center = size // 2 31 | mask[center - size//4:center + size//4, center - size//4:center + size//4] = 1 32 | return mask 33 | 34 | 35 | @pytest.mark.parametrize( 36 | ["transform_cls", "downscale_params", "upscale_params"], 37 | [ 38 | ( 39 | A.RandomScale, 40 | {"scale_limit": (-0.5, -0.5)}, # Fixed 0.5 downscale (100px→50px) 41 | {"scale_limit": (1.0, 1.0)} # Fixed 2.0 upscale (50px→100px) 42 | ), 43 | ( 44 | A.Resize, 45 | {"height": 50, "width": 50}, # Downscale 100px→50px 46 | {"height": 100, "width": 100} # Upscale 50px→100px 47 | ), 48 | ( 49 | A.LongestMaxSize, 50 | {"max_size": 50}, # Downscale 100px→50px 51 | {"max_size": 100} # Upscale 50px→100px 52 | ), 53 | ( 54 | A.SmallestMaxSize, 55 | {"max_size": 50}, # Downscale 100px→50px 56 | {"max_size": 100} # Upscale 50px→100px 57 | ), 58 | ], 59 | ) 60 | @pytest.mark.parametrize( 61 | "interpolation", 62 | [ 63 | cv2.INTER_NEAREST, 64 | cv2.INTER_LINEAR, 65 | cv2.INTER_CUBIC, 66 | cv2.INTER_LANCZOS4, 67 | ], 68 | ) 69 | class TestAreaForDownscaleOutput: 70 | def test_downscale_area_option_matches_area_interp(self, transform_cls, downscale_params, upscale_params, interpolation): 71 | """Test that specified interpolation with area_for_downscale='image' matches AREA without area_for_downscale for downscaling.""" 72 | image = get_downscale_image() 73 | 74 | # Transform 1: Specified interpolation with area_for_downscale="image" 75 | transform1 = transform_cls( 76 | interpolation=interpolation, 77 | area_for_downscale="image", 78 | p=1.0, 79 | **downscale_params 80 | ) 81 | result1 = transform1(image=image) 82 | 83 | # Transform 2: AREA interpolation without area_for_downscale 84 | transform2 = transform_cls( 85 | interpolation=cv2.INTER_AREA, 86 | area_for_downscale=None, 87 | p=1.0, 88 | **downscale_params 89 | ) 90 | result2 = transform2(image=image) 91 | 92 | # The image outputs should be identical 93 | np.testing.assert_array_equal( 94 | result1["image"], 95 | result2["image"], 96 | err_msg=f"Downscale outputs differ for {transform_cls.__name__} with {interpolation}+area_for_image vs AREA" 97 | ) 98 | 99 | def test_downscale_area_option_for_mask(self, transform_cls, downscale_params, upscale_params, interpolation): 100 | """Test that area_for_downscale='image_mask' affects mask interpolation.""" 101 | image = get_downscale_image() 102 | mask = get_mask(100) 103 | 104 | # Transform 1: With area_for_downscale="image" (should not affect mask) 105 | transform1 = transform_cls( 106 | interpolation=interpolation, 107 | mask_interpolation=cv2.INTER_NEAREST, 108 | area_for_downscale="image", 109 | p=1.0, 110 | **downscale_params 111 | ) 112 | result1 = transform1(image=image, mask=mask) 113 | 114 | # Transform 2: With area_for_downscale="image_mask" (should affect mask) 115 | transform2 = transform_cls( 116 | interpolation=interpolation, 117 | mask_interpolation=cv2.INTER_NEAREST, 118 | area_for_downscale="image_mask", 119 | p=1.0, 120 | **downscale_params 121 | ) 122 | result2 = transform2(image=image, mask=mask) 123 | 124 | # Both should produce identical images (both using AREA for downscaling) 125 | np.testing.assert_array_equal( 126 | result1["image"], 127 | result2["image"], 128 | err_msg=f"Image outputs differ between area_for_downscale='image' and 'image_mask'" 129 | ) 130 | 131 | # Transform 3: With AREA mask interpolation (should match area_for_downscale="image_mask") 132 | transform3 = transform_cls( 133 | interpolation=interpolation, 134 | mask_interpolation=cv2.INTER_AREA, 135 | area_for_downscale=None, 136 | p=1.0, 137 | **downscale_params 138 | ) 139 | result3 = transform3(image=image, mask=mask) 140 | 141 | # Mask from transform2 (NEAREST+area_for_image_mask) should match transform3 (AREA) 142 | np.testing.assert_array_equal( 143 | result2["mask"], 144 | result3["mask"], 145 | err_msg=f"Mask with NEAREST+area_for_image_mask should match mask with AREA interpolation" 146 | ) 147 | 148 | def test_upscale_ignores_area_for_downscale(self, transform_cls, downscale_params, upscale_params, interpolation): 149 | """Test that area_for_downscale has no effect when upscaling.""" 150 | image = get_upscale_image() 151 | 152 | # Transform 1: With area_for_downscale="image" 153 | transform1 = transform_cls( 154 | interpolation=interpolation, 155 | area_for_downscale="image", 156 | p=1.0, 157 | **upscale_params 158 | ) 159 | result1 = transform1(image=image) 160 | 161 | # Transform 2: Without area_for_downscale 162 | transform2 = transform_cls( 163 | interpolation=interpolation, 164 | area_for_downscale=None, 165 | p=1.0, 166 | **upscale_params 167 | ) 168 | result2 = transform2(image=image) 169 | 170 | # Outputs should be identical since area_for_downscale shouldn't affect upscaling 171 | np.testing.assert_array_equal( 172 | result1["image"], 173 | result2["image"], 174 | err_msg=f"Upscale outputs differ with/without area_for_downscale" 175 | ) 176 | -------------------------------------------------------------------------------- /tests/test_targets.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | import albumentations as A 7 | from albumentations.core.type_definitions import ALL_TARGETS, Targets 8 | 9 | from tests.conftest import SQUARE_FLOAT_IMAGE 10 | from .utils import get_dual_transforms, get_image_only_transforms 11 | 12 | 13 | def get_targets_from_methods(cls): 14 | targets = {Targets.IMAGE, Targets.MASK, Targets.VOLUME, Targets.MASK3D} 15 | 16 | has_bboxes_method = any( 17 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.DualTransform, attr, None) 18 | for attr in ["apply_to_bbox", "apply_to_bboxes"] 19 | ) 20 | if has_bboxes_method: 21 | targets.add(Targets.BBOXES) 22 | 23 | has_keypoints_method = any( 24 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.DualTransform, attr, None) 25 | for attr in ["apply_to_keypoint", "apply_to_keypoints"] 26 | ) 27 | if has_keypoints_method: 28 | targets.add(Targets.KEYPOINTS) 29 | 30 | return targets 31 | 32 | 33 | def extract_targets_from_docstring(cls): 34 | # Access the class's docstring 35 | if not (docstring := cls.__doc__): 36 | return [] # Return an empty list if there's no docstring 37 | 38 | # Regular expression to match the 'Targets:' section in the docstring 39 | targets_pattern = r"Targets:\s*([^\n]+)" 40 | 41 | # Search for the pattern in the docstring and extract targets if found 42 | if matches := re.search(targets_pattern, docstring): 43 | # Extract the targets string and split it by commas or spaces 44 | targets = re.split(r"[,\s]+", matches[1]) # Using subscript notation instead of group() 45 | return [target.strip() for target in targets if target.strip()] # Remove any extra whitespace 46 | 47 | return [] # Return an empty list if the 'Targets:' section isn't found 48 | 49 | 50 | DUAL_TARGETS = { 51 | A.OverlayElements: (Targets.IMAGE, Targets.MASK), 52 | A.Mosaic: (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS), 53 | } 54 | 55 | 56 | str2target = { 57 | "image": Targets.IMAGE, 58 | "mask": Targets.MASK, 59 | "bboxes": Targets.BBOXES, 60 | "keypoints": Targets.KEYPOINTS, 61 | "volume": Targets.VOLUME, 62 | "mask3d": Targets.MASK3D, 63 | } 64 | 65 | 66 | @pytest.mark.parametrize( 67 | ["augmentation_cls", "params"], 68 | get_image_only_transforms( 69 | custom_arguments={ 70 | A.TextImage: dict(font_path="./tests/filesLiberationSerif-Bold.ttf"), 71 | }, 72 | ), 73 | ) 74 | def test_image_only(augmentation_cls, params): 75 | aug = augmentation_cls(p=1, **params) 76 | assert aug._targets == (Targets.IMAGE, Targets.VOLUME) 77 | 78 | 79 | @pytest.mark.parametrize( 80 | ["augmentation_cls", "params"], 81 | get_dual_transforms( 82 | custom_arguments={ 83 | A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, 84 | A.CenterCrop: {"height": 10, "width": 10}, 85 | A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10}, 86 | A.RandomCrop: {"height": 10, "width": 10}, 87 | A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10}, 88 | A.RandomResizedCrop: {"size": (10, 10)}, 89 | A.RandomSizedCrop: {"min_max_height": (4, 8), "size": (10, 10)}, 90 | A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10}, 91 | A.CropAndPad: {"px": 10}, 92 | A.Resize: {"height": 10, "width": 10}, 93 | A.XYMasking: { 94 | "num_masks_x": (1, 3), 95 | "num_masks_y": (1, 3), 96 | "mask_x_length": 10, 97 | "mask_y_length": 10, 98 | "fill_mask": 1, 99 | "fill": 0, 100 | }, 101 | A.GridElasticDeform: {"num_grid_xy": (10, 10), "magnitude": 10}, 102 | A.Mosaic: {}, 103 | }, 104 | ), 105 | ) 106 | def test_dual(augmentation_cls, params): 107 | aug = augmentation_cls(p=1, **params) 108 | assert set(aug._targets) == set(DUAL_TARGETS.get(augmentation_cls, ALL_TARGETS)) 109 | assert set(aug._targets) <= get_targets_from_methods(augmentation_cls) 110 | 111 | targets_from_docstring = {str2target[target] for target in extract_targets_from_docstring(augmentation_cls)} 112 | 113 | assert set(aug._targets) == targets_from_docstring 114 | -------------------------------------------------------------------------------- /tests/test_text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import random 4 | from albucore import get_num_channels 5 | from PIL import Image, ImageFont 6 | 7 | import albumentations.augmentations.text.functional as ftext 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "sentence, num_words, expected_length", 12 | [ 13 | ("The quick brown fox jumps over the lazy dog", 0, 9), # num_words=0 should delete 0 words 14 | ("The quick brown fox jumps over the lazy dog", 9, 0), # num_words=9 should delete all words 15 | ("The quick brown fox jumps over the lazy dog", 4, 5), # num_words=4 should delete 4 words 16 | ("Hello world", 1, 1), # num_words=1 should delete 1 word 17 | ("Hello", 1, 0), # Single word sentence should be deleted 18 | ], 19 | ) 20 | def test_delete_random_words(sentence, num_words, expected_length): 21 | words = sentence.split() 22 | result = ftext.delete_random_words(words, num_words, random.Random(42)) 23 | result_length = len(result.split()) 24 | assert expected_length == result_length 25 | if num_words == 0: 26 | assert result == sentence # No words should be deleted 27 | if num_words >= len(sentence.split()): 28 | assert result == "" # All words should be deleted 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "sentence, num_words", 33 | [ 34 | ("The quick brown fox jumps over the lazy dog", 0), # No swaps 35 | ("The quick brown fox jumps over the lazy dog", 1), # One swap 36 | ("The quick brown fox jumps over the lazy dog", 3), # Three swaps 37 | ("Hello world", 1), # Single swap for two words 38 | ("Hello", 1), # Single word should remain unchanged 39 | ], 40 | ) 41 | def test_swap_random_words(sentence, num_words): 42 | 43 | words_in_sentence = sentence.split(" ") 44 | 45 | result = ftext.swap_random_words(words_in_sentence, num_words, random.Random(42)) 46 | words_in_result = result.split(" ") 47 | 48 | # Handle single word case 49 | if len(words_in_sentence) == 1: 50 | assert result == sentence, "Single word input should remain unchanged" 51 | else: 52 | assert ( 53 | words_in_result != words_in_sentence or num_words == 0 54 | ), f"Result should be different from input for n={num_words}" 55 | assert len(words_in_result) == len(words_in_sentence), "Result should have the same number of words as input" 56 | assert sorted(words_in_result) == sorted(words_in_sentence), "Result should contain the same words as input" 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "sentence, num_insertions, stopwords, expected_length_range", 61 | [ 62 | ("The quick brown fox jumps over the lazy dog", 0, None, (9, 9)), # No insertions 63 | ("The quick brown fox jumps over the lazy dog", 1, None, (10, 10)), # One insertion with default stopwords 64 | ("The quick brown fox jumps over the lazy dog", 3, None, (12, 12)), # Three insertions with default stopwords 65 | ( 66 | "The quick brown fox jumps over the lazy dog", 67 | 1, 68 | ["a", "b", "c"], 69 | (10, 10), 70 | ), # One insertion with custom stopwords 71 | ("Hello world", 1, None, (3, 3)), # Single insertion for two words 72 | ("Hello", 1, None, (2, 2)), # Single insertion for a single word 73 | ], 74 | ) 75 | def test_insert_random_stopwords(sentence, num_insertions, stopwords, expected_length_range): 76 | words = sentence.split() 77 | result = ftext.insert_random_stopwords(words, num_insertions, stopwords, random.Random(42)) 78 | result_length = len(result.split()) 79 | 80 | # Ensure the result length is within the expected range 81 | assert ( 82 | expected_length_range[0] <= result_length <= expected_length_range[1] 83 | ), f"Result length {result_length} not in expected range {expected_length_range} for input '{sentence}' with num_insertions={num_insertions}" 84 | 85 | # Check if the number of words increased correctly 86 | assert ( 87 | result_length == len(sentence.split()) + num_insertions 88 | ), "The number of words in the result should be the original number plus the number of insertions" 89 | 90 | # Ensure all inserted words are from the stopwords list 91 | if stopwords is None: 92 | stopwords = ["and", "the", "is", "in", "at", "of"] 93 | inserted_words = [word for word in result.split() if word not in sentence.split()] 94 | assert all(word in stopwords for word in inserted_words), "All inserted words should be from the stopwords list" 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "image_shape", 99 | [ 100 | (100, 100), # Grayscale image 101 | (100, 100, 1), # Single channel image 102 | (100, 100, 3), # RGB image 103 | ], 104 | ) 105 | def test_convert_image_to_pil(image_shape): 106 | image = np.random.randint(0, 255, size=image_shape, dtype=np.uint8) 107 | pil_image = ftext.convert_image_to_pil(image) 108 | assert isinstance(pil_image, Image.Image) 109 | 110 | 111 | font = ImageFont.truetype("./tests/files/LiberationSerif-Bold.ttf", size=12) 112 | 113 | dummy_metadata = { 114 | "bbox_coords": (10, 10, 100, 50), 115 | "text": "Test", 116 | "font": font, 117 | "font_color": (255, 0, 0), # Red color 118 | } 119 | 120 | 121 | @pytest.mark.parametrize( 122 | "image_shape, metadata_list", 123 | [ 124 | ( 125 | (100, 100), 126 | [ 127 | { 128 | "bbox_coords": (10, 10, 100, 50), 129 | "text": "Test", 130 | "font": font, 131 | "font_color": (127,), # Grayscale color 132 | }, 133 | ], 134 | ), # Grayscale image 135 | ( 136 | (100, 100, 1), 137 | [ 138 | { 139 | "bbox_coords": (10, 10, 100, 50), 140 | "text": "Test", 141 | "font": font, 142 | "font_color": (127,), # Single channel color 143 | }, 144 | ], 145 | ), # Single channel image 146 | ( 147 | (100, 100, 3), 148 | [ 149 | { 150 | "bbox_coords": (10, 10, 100, 50), 151 | "text": "Test", 152 | "font": font, 153 | "font_color": (127, 127, 127), # RGB color 154 | }, 155 | { 156 | "bbox_coords": (20, 20, 110, 60), 157 | "text": "Test", 158 | "font": font, 159 | "font_color": (255, 0, 0), # Red color 160 | }, 161 | ], 162 | ), # RGB image with tuple colors 163 | ( 164 | (100, 100, 5), 165 | [ 166 | { 167 | "bbox_coords": (20, 20, 110, 60), 168 | "text": "Test", 169 | "font": font, 170 | "font_color": (127, 127, 127, 127, 127), # 5-channel color 171 | }, 172 | ], 173 | ), 174 | ( 175 | (100, 100, 5), 176 | [ 177 | { 178 | "bbox_coords": (20, 20, 110, 60), 179 | "text": "Test", 180 | "font": font, 181 | "font_color": (255, 0, 0), # RGB color (will be padded for 5 channels) 182 | }, 183 | ], 184 | ), 185 | ], 186 | ) 187 | def test_draw_text_on_pil_image(image_shape, metadata_list): 188 | image = np.random.randint(0, 255, size=image_shape, dtype=np.uint8) 189 | 190 | if get_num_channels(image) in {1, 3}: 191 | pil_image = ftext.convert_image_to_pil(image) 192 | result = ftext.draw_text_on_pil_image(pil_image, metadata_list) 193 | assert isinstance(result, Image.Image) 194 | else: 195 | result = ftext.draw_text_on_multi_channel_image(image, metadata_list) 196 | assert isinstance(result, np.ndarray) 197 | -------------------------------------------------------------------------------- /tests/transforms3d/test_functions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | import numpy as np 5 | from albumentations.augmentations.transforms3d import functional as f3d 6 | 7 | @pytest.mark.parametrize( 8 | "input_shape,n_channels", [ 9 | ((3, 3, 3), None), # 3D case 10 | ((3, 3, 3, 2), 2), # 4D case 11 | ] 12 | ) 13 | def test_uniqueness(input_shape: tuple, n_channels: int | None): 14 | # Create test cube with unique values 15 | n_elements = np.prod(input_shape) 16 | test_cube = np.arange(n_elements).reshape(input_shape) 17 | 18 | # Generate all 48 transformations 19 | transformations = [f3d.transform_cube(test_cube, i) for i in range(48)] 20 | 21 | # Check uniqueness 22 | unique_transforms = set(str(t) for t in transformations) 23 | assert len(unique_transforms) == 48, "Not all transformations are unique!" 24 | 25 | # Check shape preservation 26 | expected_shape = input_shape 27 | for t in transformations: 28 | assert t.shape == expected_shape, f"Wrong shape: got {t.shape}, expected {expected_shape}" 29 | 30 | 31 | @pytest.mark.parametrize( 32 | ["keypoints", "holes", "expected"], 33 | [ 34 | # Basic case: single hole, some points inside/outside 35 | ( 36 | np.array([[1, 1, 1], [5, 5, 5], [8, 8, 8]], dtype=np.float32), # keypoints (XYZ) 37 | np.array([[4, 4, 4, 6, 6, 6]], dtype=np.float32), # holes (Z1,Y1,X1,Z2,Y2,X2) 38 | np.array([[1, 1, 1], [8, 8, 8]], dtype=np.float32), # expected (points outside hole) 39 | ), 40 | # Multiple holes 41 | ( 42 | np.array([[1, 1, 1], [5, 5, 5], [8, 8, 8]], dtype=np.float32), 43 | np.array([ 44 | [0, 0, 0, 2, 2, 2], # hole covering [1,1,1] 45 | [4, 4, 4, 6, 6, 6], # hole covering [5,5,5] 46 | ], dtype=np.float32), 47 | np.array([[8, 8, 8]], dtype=np.float32), # only last point survives 48 | ), 49 | # Edge cases: points exactly on boundaries 50 | ( 51 | np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]], dtype=np.float32), 52 | np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32), # hole with points on boundaries 53 | np.array([[2, 2, 2], [3, 3, 3]], dtype=np.float32), # points on/outside boundaries remain 54 | ), 55 | # Empty arrays 56 | ( 57 | np.array([], dtype=np.float32).reshape(0, 3), # no keypoints 58 | np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32), 59 | np.array([], dtype=np.float32).reshape(0, 3), 60 | ), 61 | ( 62 | np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32), 63 | np.array([], dtype=np.float32).reshape(0, 6), # no holes 64 | np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32), 65 | ), 66 | # Extra keypoint dimensions 67 | ( 68 | np.array([[1, 1, 1, 0.9], [5, 5, 5, 0.8]], dtype=np.float32), 69 | np.array([[4, 4, 4, 6, 6, 6]], dtype=np.float32), 70 | np.array([[1, 1, 1, 0.9]], dtype=np.float32), 71 | ), 72 | # Overlapping holes 73 | ( 74 | np.array([[1, 1, 1], [3, 3, 3], [5, 5, 5]], dtype=np.float32), 75 | np.array([ 76 | [0, 0, 0, 4, 4, 4], 77 | [2, 2, 2, 6, 6, 6], 78 | ], dtype=np.float32), 79 | np.array([], dtype=np.float32).reshape(0, 3), # all points inside holes 80 | ), 81 | ], 82 | ) 83 | def test_filter_keypoints_in_holes3d(keypoints, holes, expected): 84 | result = f3d.filter_keypoints_in_holes3d(keypoints, holes) 85 | np.testing.assert_array_equal( 86 | result, expected, 87 | err_msg=f"Failed with keypoints {keypoints} and holes {holes}" 88 | ) 89 | 90 | def test_filter_keypoints_in_holes3d_invalid_input(): 91 | """Test error handling for invalid input shapes.""" 92 | # Invalid keypoint dimensions 93 | with pytest.raises(IndexError): 94 | f3d.filter_keypoints_in_holes3d( 95 | np.array([[1, 1]]), # only 2D points 96 | np.array([[0, 0, 0, 1, 1, 1]]) 97 | ) 98 | 99 | # Invalid hole dimensions 100 | with pytest.raises(IndexError): 101 | f3d.filter_keypoints_in_holes3d( 102 | np.array([[1, 1, 1]]), 103 | np.array([[0, 0, 0, 1]]) # incomplete hole specification 104 | ) 105 | 106 | def test_filter_keypoints_in_holes3d_random(): 107 | """Test with random data to ensure robustness.""" 108 | rng = np.random.default_rng(42) 109 | 110 | # Generate random keypoints and holes 111 | num_keypoints = 100 112 | num_holes = 10 113 | volume_size = 100 114 | 115 | keypoints = rng.integers(0, volume_size, (num_keypoints, 3)) 116 | holes = np.array([ 117 | [ 118 | rng.integers(0, volume_size-10), # z1 119 | rng.integers(0, volume_size-10), # y1 120 | rng.integers(0, volume_size-10), # x1 121 | rng.integers(10, volume_size), # z2 122 | rng.integers(10, volume_size), # y2 123 | rng.integers(10, volume_size), # x2 124 | ] 125 | for _ in range(num_holes) 126 | ]) 127 | 128 | # Ensure z2>z1, y2>y1, x2>x1 for each hole 129 | holes[:, 3:] = holes[:, :3] + holes[:, 3:] 130 | 131 | # Test function 132 | result = f3d.filter_keypoints_in_holes3d(keypoints, holes) 133 | 134 | # Verify each surviving point is actually outside all holes 135 | for point in result: 136 | x, y, z = point 137 | for z1, y1, x1, z2, y2, x2 in holes: 138 | assert not ( 139 | z1 <= z < z2 and 140 | y1 <= y < y2 and 141 | x1 <= x < x2 142 | ), f"Point {point} should be outside hole {[z1,y1,x1,z2,y2,x2]}" 143 | 144 | @pytest.mark.parametrize( 145 | "factor", # Remove the brackets - it's just the parameter name 146 | [ 147 | 1, 148 | 2, 149 | 3, 150 | -1, 151 | 5, 152 | ] 153 | ) 154 | @pytest.mark.parametrize( 155 | "axes", # Remove the brackets - it's just the parameter name 156 | [ 157 | (0, 1), # rotate in HW plane 158 | (0, 2), # rotate in HD plane 159 | (1, 2), # rotate in WD plane 160 | ] 161 | ) 162 | def test_keypoints_rot90_matches_numpy(factor, axes): 163 | """Test that keypoints_rot90 matches np.rot90 behavior.""" 164 | # Create volume with different dimensions to catch edge cases 165 | volume = np.zeros((5, 6, 7), dtype=np.uint8) # (H, W, D) 166 | 167 | # Create test points (avoiding edges for clear results) 168 | keypoints = np.array([ 169 | [1, 1, 1], # XYZ coordinates 170 | [1, 3, 1], 171 | [3, 1, 3], 172 | [2, 2, 2], 173 | ], dtype=np.float32) 174 | 175 | # Convert keypoints from XYZ to HWD ordering 176 | keypoints_hwd = keypoints[:, [2, 1, 0]] # XYZ -> HWD 177 | 178 | # Mark points in volume 179 | for h, w, d in keypoints_hwd: 180 | volume[int(h), int(w), int(d)] = 1 181 | 182 | # Rotate volume 183 | rotated_volume = np.rot90(volume.copy(), factor, axes=axes) 184 | 185 | # Rotate keypoints 186 | rotated_keypoints_hwd = f3d.keypoints_rot90(keypoints_hwd, factor, axes, volume_shape=volume.shape) 187 | 188 | # Convert back to XYZ for verification 189 | rotated_keypoints = rotated_keypoints_hwd[:, [2, 1, 0]] # HWD -> XYZ 190 | 191 | # Verify each rotated keypoint matches a marked point in rotated volume 192 | for x, y, z in rotated_keypoints: 193 | assert rotated_volume[int(z), int(y), int(x)] == 1, ( 194 | f"Keypoint at ({x}, {y}, {z}) should match marked point in volume " 195 | f"after rotation with factor={factor}, axes={axes}" 196 | ) 197 | 198 | 199 | @pytest.mark.parametrize("index", range(48)) 200 | def test_transform_cube_keypoints_matches_transform_cube(index): 201 | """Test that transform_cube_keypoints matches transform_cube behavior and preserves extra columns.""" 202 | # Create volume with different dimensions to catch edge cases 203 | volume = np.zeros((5, 6, 7), dtype=np.uint8) # (D, H, W) 204 | 205 | # Create test points with additional columns 206 | keypoints = np.array([ 207 | [1, 1, 1, 0.5, 0.6, 0.7], # XYZ coordinates + 3 extra values 208 | [1, 3, 1, 0.2, 0.3, 0.4], 209 | [3, 1, 3, 0.8, 0.9, 1.0], 210 | [2, 2, 2, 0.1, 0.2, 0.3], 211 | ], dtype=np.float32) 212 | 213 | # Store original extra columns for comparison 214 | original_extra_cols = keypoints[:, 3:].copy() 215 | 216 | # Mark points in volume (converting from XYZ to DHW) 217 | for x, y, z in keypoints[:, :3]: 218 | volume[int(z), int(y), int(x)] = 1 219 | 220 | # Transform volume 221 | transformed_volume = f3d.transform_cube(volume.copy(), index) 222 | 223 | # Transform keypoints 224 | transformed_keypoints = f3d.transform_cube_keypoints(keypoints.copy(), index, volume_shape=volume.shape) 225 | 226 | # Verify each transformed keypoint matches a marked point in transformed volume 227 | for x, y, z in transformed_keypoints[:, :3]: 228 | assert transformed_volume[int(z), int(y), int(x)] == 1, ( 229 | f"Keypoint at ({x}, {y}, {z}) should match marked point in volume " 230 | f"after transformation with index={index}" 231 | ) 232 | 233 | # Verify extra columns remain unchanged 234 | np.testing.assert_array_equal( 235 | transformed_keypoints[:, 3:], 236 | original_extra_cols, 237 | err_msg=f"Extra columns should remain unchanged after transformation with index={index}" 238 | ) 239 | -------------------------------------------------------------------------------- /tests/transforms3d/test_pytorch.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | import albumentations as A 5 | 6 | test_cases = [ 7 | pytest.param( 8 | (64, 64, 64, 3), # volume shape 9 | (64, 64, 64, 1), # mask shape 10 | (3, 64, 64, 64), # expected volume shape 11 | (1, 64, 64, 64), # expected mask shape 12 | id="rgb_volume_single_channel_mask" 13 | ), 14 | pytest.param( 15 | (64, 64, 64), # volume shape 16 | (64, 64, 64), # mask shape 17 | (1, 64, 64, 64), # expected volume shape 18 | (1, 64, 64, 64), # expected mask shape 19 | id="grayscale_volume_and_mask" 20 | ), 21 | pytest.param( 22 | (64, 64, 64, 3), # volume shape 23 | (64, 64, 64, 4), # mask shape 24 | (3, 64, 64, 64), # expected volume shape 25 | (4, 64, 64, 64), # expected mask shape 26 | id="rgb_volume_multi_channel_mask" 27 | ), 28 | ] 29 | 30 | @pytest.mark.parametrize( 31 | "volume_shape,mask_shape,expected_volume_shape,expected_mask_shape", 32 | test_cases 33 | ) 34 | def test_to_tensor_3d_shapes( 35 | volume_shape, 36 | mask_shape, 37 | expected_volume_shape, 38 | expected_mask_shape 39 | ): 40 | transform = A.Compose([A.ToTensor3D(p=1)]) 41 | volume = np.random.randint(0, 256, volume_shape, dtype=np.uint8) 42 | mask3d = np.random.randint(0, 2, mask_shape, dtype=np.uint8) 43 | 44 | transformed = transform(volume=volume, mask3d=mask3d) 45 | 46 | assert isinstance(transformed["volume"], torch.Tensor) 47 | assert isinstance(transformed["mask3d"], torch.Tensor) 48 | assert transformed["volume"].shape == expected_volume_shape 49 | assert transformed["mask3d"].shape == expected_mask_shape 50 | 51 | 52 | error_test_cases = [ 53 | pytest.param( 54 | (64, 64), # 2D array 55 | TypeError, 56 | "volume must be 3D or 4D array", 57 | id="2d_array" 58 | ), 59 | pytest.param( 60 | (64, 64, 64, 3, 1), # 5D array 61 | TypeError, 62 | "volume must be 3D or 4D array", 63 | id="5d_array" 64 | ), 65 | ] 66 | 67 | @pytest.mark.parametrize( 68 | "volume_shape,expected_error,expected_message", 69 | error_test_cases 70 | ) 71 | def test_to_tensor_3d_errors(volume_shape, expected_error, expected_message): 72 | transform = A.Compose([A.ToTensor3D(p=1)]) 73 | volume = np.random.rand(*volume_shape) 74 | 75 | with pytest.raises(expected_error, match=expected_message): 76 | transform(volume=volume) 77 | -------------------------------------------------------------------------------- /tests/transforms3d/test_targets.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | import albumentations as A 7 | from albumentations.core.type_definitions import Targets 8 | from tests.utils import get_3d_transforms 9 | 10 | 11 | def extract_targets_from_docstring(cls): 12 | # Access the class's docstring 13 | docstring = cls.__doc__ 14 | if not docstring: 15 | return [] # Return an empty list if there's no docstring 16 | 17 | # Regular expression to match the 'Targets:' section in the docstring 18 | targets_pattern = r"Targets:\s*([^\n]+)" 19 | 20 | # Search for the pattern in the docstring 21 | matches = re.search(targets_pattern, docstring) 22 | if matches: 23 | # Extract the targets string and split it by commas or spaces 24 | targets_str = matches.group(1) 25 | targets = re.split(r"[,\s]+", targets_str) # Split by comma or whitespace 26 | return [target.strip() for target in targets if target.strip()] # Remove any extra whitespace 27 | return [] # Return an empty list if the 'Targets:' section isn't found 28 | 29 | 30 | def get_targets_from_methods(cls): 31 | targets = {Targets.VOLUME, Targets.MASK3D} 32 | 33 | has_volume_method = any( 34 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) 35 | for attr in ["apply_to_volume"] 36 | ) 37 | if has_volume_method: 38 | targets.add(Targets.VOLUME) 39 | 40 | has_masks_method = any( 41 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) 42 | for attr in ["apply_to_mask"] 43 | ) 44 | if has_masks_method: 45 | targets.add(Targets.MASK) 46 | 47 | has_masks3d_method = any( 48 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) 49 | for attr in ["apply_to_mask3d"] 50 | ) 51 | if has_masks3d_method: 52 | targets.add(Targets.MASK3D) 53 | 54 | has_bboxes_method = any( 55 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) 56 | for attr in ["apply_to_bboxes"] 57 | ) 58 | if has_bboxes_method: 59 | targets.add(Targets.BBOXES) 60 | 61 | has_keypoints_method = any( 62 | hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) 63 | for attr in ["apply_to_keypoints"] 64 | ) 65 | if has_keypoints_method: 66 | targets.add(Targets.KEYPOINTS) 67 | 68 | return targets 69 | 70 | TRASNFORM_3D_TARGETS = { 71 | } 72 | 73 | str2target = { 74 | "mask3d": Targets.MASK3D, 75 | "volume": Targets.VOLUME, 76 | "keypoints": Targets.KEYPOINTS, 77 | } 78 | 79 | @pytest.mark.parametrize( 80 | ["augmentation_cls", "params"], 81 | get_3d_transforms(custom_arguments={ 82 | A.PadIfNeeded3D: {"min_zyx": (4, 250, 230), "position": "center", "fill": 0, "fill_mask": 0}, 83 | A.Pad3D: {"padding": 10}, 84 | A.RandomCrop3D: {"size": (2, 30, 30), "pad_if_needed": True}, 85 | A.CenterCrop3D: {"size": (2, 30, 30), "pad_if_needed": True}, 86 | }) 87 | ) 88 | def test_transform3d(augmentation_cls, params): 89 | aug = augmentation_cls(p=1, **params) 90 | assert set(aug._targets) == set(TRASNFORM_3D_TARGETS.get(augmentation_cls, {Targets.MASK3D, Targets.VOLUME, Targets.KEYPOINTS})) 91 | assert set(aug._targets) <= get_targets_from_methods(augmentation_cls) 92 | 93 | targets_from_docstring = {str2target[target] for target in extract_targets_from_docstring(augmentation_cls)} 94 | 95 | assert set(aug._targets) == targets_from_docstring 96 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | import inspect 5 | import random 6 | from io import StringIO 7 | 8 | import numpy as np 9 | 10 | import albumentations 11 | from tests.aug_definitions import AUGMENTATION_CLS_PARAMS 12 | 13 | 14 | def convert_2d_to_3d(arrays, num_channels=3): 15 | # Converts a 2D numpy array with shape (H, W) into a 3D array with shape (H, W, num_channels) 16 | # by repeating the existing values along the new axis. 17 | arrays = tuple(np.repeat(array[:, :, np.newaxis], repeats=num_channels, axis=2) for array in arrays) 18 | if len(arrays) == 1: 19 | return arrays[0] 20 | return arrays 21 | 22 | 23 | def convert_2d_to_target_format(arrays, target): 24 | if target == "mask": 25 | return arrays[0] if len(arrays) == 1 else arrays 26 | if target == "image": 27 | return convert_2d_to_3d(arrays, num_channels=3) 28 | if target == "image_4_channels": 29 | return convert_2d_to_3d(arrays, num_channels=4) 30 | 31 | raise ValueError(f"Unknown target {target}") 32 | 33 | 34 | class InMemoryFile(StringIO): 35 | def __init__(self, value, save_value, file): 36 | super().__init__(value) 37 | self.save_value = save_value 38 | self.file = file 39 | 40 | def close(self): 41 | self.save_value(self.getvalue(), self.file) 42 | super().close() 43 | 44 | 45 | class OpenMock: 46 | """Mocks the `open` built-in function. A call to the instance of OpenMock returns an in-memory file which is 47 | readable and writable. The actual in-memory file implementation should call the passed `save_value` method 48 | to save the file content in the cache when the file is being closed to preserve the file content. 49 | """ 50 | 51 | def __init__(self): 52 | self.values = {} 53 | 54 | def __call__(self, file, *args, **kwargs): 55 | value = self.values.get(file) 56 | return InMemoryFile(value, self.save_value, file) 57 | 58 | def save_value(self, value, file): 59 | self.values[file] = value 60 | 61 | 62 | def set_seed(seed): 63 | random.seed(seed) 64 | np.random.seed(seed) 65 | 66 | 67 | def get_all_valid_transforms(use_cache=False): 68 | """ 69 | Find all transforms that are children of BasicTransform or BaseCompose, 70 | and do not have DeprecationWarning or FutureWarning. 71 | 72 | Args: 73 | use_cache (bool): Whether to cache the results using lru_cache. Default: False 74 | """ 75 | if use_cache: 76 | return _get_all_valid_transforms_cached() 77 | return _get_all_valid_transforms() 78 | 79 | 80 | @functools.lru_cache(maxsize=None) 81 | def _get_all_valid_transforms_cached(): 82 | return _get_all_valid_transforms() 83 | 84 | 85 | def _get_all_valid_transforms(): 86 | valid_transforms = [] 87 | for _, cls in inspect.getmembers(albumentations): 88 | if not inspect.isclass(cls) or not issubclass(cls, (albumentations.BasicTransform, albumentations.BaseCompose)): 89 | continue 90 | 91 | valid_transforms.append(cls) 92 | return valid_transforms 93 | 94 | 95 | def get_filtered_transforms( 96 | base_classes, 97 | custom_arguments=None, 98 | except_augmentations=None, 99 | exclude_base_classes=None, 100 | ): 101 | custom_arguments = custom_arguments or {} 102 | except_augmentations = except_augmentations or set() 103 | exclude_base_classes = exclude_base_classes or () 104 | 105 | # Create a mapping of transform class to params from AUGMENTATION_CLS_PARAMS 106 | default_params = {} 107 | for transform_entry in AUGMENTATION_CLS_PARAMS: 108 | transform_cls = transform_entry[0] 109 | params = transform_entry[1] 110 | 111 | # Convert single dict to list for uniform handling 112 | if isinstance(params, dict): 113 | params = [params] 114 | 115 | if transform_cls not in default_params: 116 | default_params[transform_cls] = [] 117 | default_params[transform_cls].extend(params) 118 | 119 | result = [] 120 | for cls in get_all_valid_transforms(): 121 | # Skip checks... 122 | if cls in except_augmentations: 123 | continue 124 | if any(cls == i for i in base_classes): 125 | continue 126 | if exclude_base_classes and issubclass(cls, exclude_base_classes): 127 | continue 128 | if not issubclass(cls, base_classes): 129 | continue 130 | 131 | # Get parameters for this transform 132 | if cls in custom_arguments: 133 | params = custom_arguments[cls] 134 | if isinstance(params, dict): 135 | params = [params] 136 | for param_set in params: 137 | result.append((cls, param_set)) 138 | elif cls in default_params: 139 | for param_set in default_params[cls]: 140 | result.append((cls, param_set)) 141 | else: 142 | result.append((cls, {})) 143 | 144 | return result 145 | 146 | 147 | def get_image_only_transforms( 148 | custom_arguments: dict[type[albumentations.ImageOnlyTransform], dict] | None = None, 149 | except_augmentations: set[type[albumentations.ImageOnlyTransform]] | None = None, 150 | ) -> list[tuple[type, dict]]: 151 | return get_filtered_transforms((albumentations.ImageOnlyTransform,), custom_arguments, except_augmentations) 152 | 153 | 154 | def get_dual_transforms( 155 | custom_arguments: dict[type[albumentations.DualTransform], dict] | None = None, 156 | except_augmentations: set[type[albumentations.DualTransform]] | None = None, 157 | ) -> list[tuple[type, dict]]: 158 | """Get all 2D dual transforms, excluding 3D transforms.""" 159 | return get_filtered_transforms( 160 | base_classes=(albumentations.DualTransform,), 161 | custom_arguments=custom_arguments, 162 | except_augmentations=except_augmentations, 163 | exclude_base_classes=(albumentations.Transform3D,) 164 | ) 165 | 166 | def get_transforms( 167 | custom_arguments: dict[type[albumentations.BasicTransform], dict] | None = None, 168 | except_augmentations: set[type[albumentations.BasicTransform]] | None = None, 169 | ) -> list[tuple[type, dict]]: 170 | """Get all transforms (2D and 3D).""" 171 | return get_filtered_transforms( 172 | base_classes=(albumentations.ImageOnlyTransform, albumentations.DualTransform, albumentations.Transform3D), 173 | custom_arguments=custom_arguments, 174 | except_augmentations=except_augmentations, 175 | ) 176 | 177 | def get_2d_transforms( 178 | custom_arguments: dict[type[albumentations.BasicTransform], dict] | None = None, 179 | except_augmentations: set[type[albumentations.BasicTransform]] | None = None, 180 | ) -> list[tuple[type, dict]]: 181 | """Get all 2D transforms (both ImageOnly and Dual transforms), excluding 3D transforms.""" 182 | return get_filtered_transforms( 183 | base_classes=(albumentations.ImageOnlyTransform, albumentations.DualTransform), 184 | custom_arguments=custom_arguments, 185 | except_augmentations=except_augmentations, 186 | exclude_base_classes=(albumentations.Transform3D,) # Exclude Transform3D and its children 187 | ) 188 | 189 | def check_all_augs_exists( 190 | augmentations: list[list], 191 | except_augmentations: set | None = None, 192 | ) -> list[tuple[type, dict]]: 193 | existed_augs = {i[0] for i in augmentations} 194 | except_augmentations = except_augmentations or set() 195 | 196 | not_existed = [] 197 | 198 | for cls, _ in get_transforms(except_augmentations=except_augmentations): 199 | if cls not in existed_augs: 200 | not_existed.append(cls.__name__) 201 | 202 | if not_existed: 203 | raise ValueError(f"These augmentations do not exist in augmentations and except_augmentations: {not_existed}") 204 | 205 | # Flatten the parameter sets into individual test cases 206 | flattened_augmentations = [] 207 | for aug_cls, params in augmentations: 208 | if isinstance(params, list): 209 | # If params is a list, create a test case for each parameter set 210 | for param_set in params: 211 | flattened_augmentations.append((aug_cls, param_set)) 212 | else: 213 | # If params is a single dict, keep as is 214 | flattened_augmentations.append((aug_cls, params)) 215 | 216 | return flattened_augmentations 217 | 218 | 219 | def get_3d_transforms( 220 | custom_arguments: dict[type[albumentations.Transform3D], dict] | None = None, 221 | except_augmentations: set[type[albumentations.Transform3D]] | None = None, 222 | ) -> list[tuple[type, dict]]: 223 | """Get all 3D transforms.""" 224 | return get_filtered_transforms( 225 | base_classes=(albumentations.Transform3D,), 226 | custom_arguments=custom_arguments, 227 | except_augmentations=except_augmentations, 228 | ) 229 | -------------------------------------------------------------------------------- /tools/check_albucore_version.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | import sys 5 | from pathlib import Path 6 | 7 | 8 | def check_albucore_version(filename: str) -> int: 9 | with Path(filename).open() as file: 10 | content = file.read() 11 | 12 | # Look for albucore in INSTALL_REQUIRES 13 | match = re.search(r'("albucore[^"]*")', content) 14 | if not match: 15 | print(f"Error: albucore not found in {filename}") 16 | return 1 17 | 18 | albucore_req = match[1] 19 | if not re.match(r'"albucore==\d+\.\d+\.\d+"', albucore_req): 20 | print(f"Error: albucore version must be exact (==) in {filename}. Found: {albucore_req}") 21 | return 1 22 | 23 | return 0 24 | 25 | 26 | if __name__ == "__main__": 27 | sys.exit(check_albucore_version("setup.py")) 28 | -------------------------------------------------------------------------------- /tools/check_defaults.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import sys 5 | 6 | import albumentations 7 | from albumentations.core.transforms_interface import BasicTransform 8 | 9 | 10 | def check_apply_methods(cls): 11 | """Check for issues in 'apply' methods related to default arguments and Optional type annotations.""" 12 | issues = [] 13 | for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): 14 | if name.startswith("apply"): 15 | signature = inspect.signature(method) 16 | issues.extend( 17 | f"Default argument found in {cls.__name__}.{name} for parameter " 18 | f"{param.name} with default value {param.default}" 19 | for param in signature.parameters.values() 20 | if param.default is not inspect.Parameter.empty 21 | ) 22 | return issues 23 | 24 | 25 | def is_subclass_of_basic_transform(cls): 26 | """Check if a given class is a subclass of BasicTransform, excluding BasicTransform itself.""" 27 | return issubclass(cls, BasicTransform) and cls is not BasicTransform 28 | 29 | 30 | def main(): 31 | issues = [] 32 | # Check all classes in the albumentations module 33 | for _name, cls in inspect.getmembers(albumentations, predicate=inspect.isclass): 34 | if is_subclass_of_basic_transform(cls): 35 | issues.extend(check_apply_methods(cls)) 36 | 37 | if issues: 38 | print("\n".join(issues)) 39 | sys.exit(1) # Exit with error status 1 if there are any issues 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /tools/check_docstrings.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | import sys 5 | from pathlib import Path 6 | 7 | 8 | def check_docstrings_for_dashes(file_path: str) -> bool: 9 | pattern = re.compile(r'["\']{3}[\s\S]+?["\']{3}') # Regex to match docstrings 10 | dash_pattern = re.compile(r"---{2,}") # Regex to match sequences of --- 11 | 12 | with Path(file_path).open(encoding="utf-8") as file: 13 | content = file.read() 14 | matches = pattern.findall(content) 15 | for match in matches: 16 | if dash_pattern.search(match): 17 | return False # Found forbidden sequence 18 | return True # No forbidden sequences found 19 | 20 | 21 | def main(): 22 | exit_code = 0 23 | for file_path in sys.argv[1:]: 24 | if not check_docstrings_for_dashes(file_path): 25 | print( 26 | f"Error in {file_path}: According to Google Style docstrings, '---' should not be used " 27 | "to underline sections. Please refer to " 28 | "https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings", 29 | ) 30 | exit_code = 1 31 | sys.exit(exit_code) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /tools/check_example_docstrings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import annotations 4 | 5 | import ast 6 | import importlib.util 7 | import inspect 8 | import sys 9 | from pathlib import Path 10 | 11 | from google_docstring_parser import parse_google_docstring 12 | 13 | TARGET_PARENT_CLASSES = {"DualTransform", "ImageOnlyTransform", "Transform3D"} 14 | 15 | # We'll check for both, but have different error messages 16 | EXAMPLES_SECTION = "Examples" 17 | EXAMPLE_SECTION = "Example" 18 | 19 | 20 | def is_target_class(cls) -> bool: 21 | """Check if a class inherits from one of the target classes.""" 22 | # Skip if the class itself is one of the target classes 23 | if cls.__name__ in TARGET_PARENT_CLASSES: 24 | return False 25 | 26 | # Get all base classes in the class's MRO (Method Resolution Order) 27 | try: 28 | bases = [base.__name__ for base in inspect.getmro(cls)] 29 | # Class should inherit from target classes but not be one itself 30 | return any(base in TARGET_PARENT_CLASSES for base in bases) 31 | except TypeError: 32 | # If we can't get the MRO, we'll use the AST-based approach 33 | # This will be handled in check_file 34 | return False 35 | 36 | 37 | def build_inheritance_map(file_path: str) -> dict[str, list[str]]: 38 | """Build a map of class names to their direct parent class names.""" 39 | with Path(file_path).open(encoding="utf-8") as f: 40 | tree = ast.parse(f.read()) 41 | 42 | inheritance_map = {} 43 | 44 | for node in ast.walk(tree): 45 | if isinstance(node, ast.ClassDef): 46 | # Get direct parent class names 47 | parent_names = [base.id for base in node.bases if isinstance(base, ast.Name)] 48 | inheritance_map[node.name] = parent_names 49 | 50 | return inheritance_map 51 | 52 | 53 | def has_target_ancestor( 54 | class_name: str, 55 | inheritance_map: dict[str, list[str]], 56 | visited: set[str] | None = None, 57 | ) -> bool: 58 | """Recursively check if a class has any target class in its ancestry.""" 59 | if visited is None: 60 | visited = set() 61 | 62 | # Avoid cycles in inheritance 63 | if class_name in visited: 64 | return False 65 | visited.add(class_name) 66 | 67 | # Base case: this is a target class 68 | if class_name in TARGET_PARENT_CLASSES: 69 | return True 70 | 71 | # Get direct parents 72 | parents = inheritance_map.get(class_name, []) 73 | 74 | return any(has_target_ancestor(parent, inheritance_map, visited) for parent in parents) 75 | 76 | 77 | def check_docstring(docstring: str, class_name: str) -> list[tuple[str, str]]: 78 | """Check the docstring for a proper Examples section.""" 79 | errors = [] 80 | 81 | if not docstring: 82 | errors.append((class_name, "Missing docstring")) 83 | return errors 84 | 85 | try: 86 | parsed = parse_google_docstring(docstring) 87 | 88 | # First check if 'Example' is used instead of 'Examples' 89 | if EXAMPLE_SECTION in parsed and EXAMPLES_SECTION not in parsed: 90 | errors.append((class_name, f"Using '{EXAMPLE_SECTION}' instead of '{EXAMPLES_SECTION}' - use plural form")) 91 | # Then check if neither is present 92 | elif all(section not in parsed for section in [EXAMPLES_SECTION, EXAMPLE_SECTION]): 93 | errors.append((class_name, f"Missing '{EXAMPLES_SECTION}' section in docstring")) 94 | except (ValueError, AttributeError, TypeError) as e: 95 | errors.append((class_name, f"Error parsing docstring: {e!s}")) 96 | 97 | return errors 98 | 99 | 100 | def check_file(file_path: str) -> list[tuple[str, str]]: 101 | """Check a file for classes that need examples in their docstrings.""" 102 | errors = [] 103 | 104 | try: 105 | # Try to import the module 106 | module_name = Path(file_path).stem 107 | spec = importlib.util.spec_from_file_location(module_name, file_path) 108 | if not spec or not spec.loader: 109 | return [] 110 | 111 | module = importlib.util.module_from_spec(spec) 112 | spec.loader.exec_module(module) 113 | 114 | # Find all classes in the module 115 | for _, obj in inspect.getmembers(module): 116 | if inspect.isclass(obj) and obj.__module__ == module.__name__ and is_target_class(obj): 117 | docstring = inspect.getdoc(obj) 118 | errors.extend(check_docstring(docstring, obj.__name__)) 119 | except (ImportError, AttributeError, ModuleNotFoundError, SyntaxError): 120 | # If module import fails, use AST to check 121 | with Path(file_path).open(encoding="utf-8") as f: 122 | content = f.read() 123 | 124 | tree = ast.parse(content) 125 | inheritance_map = build_inheritance_map(file_path) 126 | 127 | # Find all class definitions 128 | for node in ast.walk(tree): 129 | if isinstance(node, ast.ClassDef): 130 | # Skip if class itself is a target class 131 | if node.name in TARGET_PARENT_CLASSES: 132 | continue 133 | 134 | # Check if this class has a target class in its ancestry 135 | if has_target_ancestor(node.name, inheritance_map): 136 | docstring = ast.get_docstring(node) 137 | errors.extend(check_docstring(docstring, node.name)) 138 | 139 | return errors 140 | 141 | 142 | def main(): 143 | """Main function for the pre-commit hook.""" 144 | files = sys.argv[1:] if len(sys.argv) > 1 else [] 145 | has_errors = False 146 | all_errors = [] 147 | 148 | for file_path in files: 149 | if not file_path.endswith(".py"): 150 | continue 151 | 152 | errors = check_file(file_path) 153 | if errors: 154 | has_errors = True 155 | all_errors.append((file_path, errors)) 156 | 157 | # Print all errors 158 | if all_errors: 159 | for file_path, errors in all_errors: 160 | file_rel_path = file_path.replace(str(Path.cwd()) + "/", "") 161 | print(f"\n{file_rel_path}:") 162 | for class_name, message in errors: 163 | print(f" - {class_name}: {message}") 164 | 165 | return 1 if has_errors else 0 166 | 167 | 168 | if __name__ == "__main__": 169 | sys.exit(main()) 170 | -------------------------------------------------------------------------------- /tools/check_naming_conflicts.py: -------------------------------------------------------------------------------- 1 | """Check for naming conflicts between modules and exported names. 2 | 3 | This script detects conflicts between module names and defined function/class names 4 | that could cause problems with frameworks like Hydra that rely on direct module paths. 5 | 6 | Example usage: 7 | python -m tools.check_naming_conflicts 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import ast 13 | import os 14 | import sys 15 | from pathlib import Path 16 | 17 | 18 | def get_module_names(base_dir: str) -> set[str]: 19 | """Find all submodule names (directories with __init__.py). 20 | 21 | Args: 22 | base_dir (str): Base directory to search within 23 | 24 | Returns: 25 | Set[str]: Set of module names (not paths, just the directory names) 26 | 27 | """ 28 | module_names = set() 29 | for dirpath, _, filenames in os.walk(base_dir): 30 | if "__init__.py" in filenames: 31 | # We want the last component of the path as the module name 32 | module_name = Path(dirpath).name 33 | if dirpath != base_dir: # Skip root dir 34 | module_names.add(module_name) 35 | return module_names 36 | 37 | 38 | def _extract_names_from_node(node) -> set[str]: 39 | """Extract names from AST nodes.""" 40 | names = set() 41 | # Classes and functions 42 | if isinstance(node, (ast.ClassDef, ast.FunctionDef)): 43 | if not node.name.startswith("_"): 44 | names.add(node.name) 45 | # Variables 46 | elif isinstance(node, ast.Assign): 47 | for target in node.targets: 48 | if isinstance(target, ast.Name) and not target.id.startswith("_"): 49 | names.add(target.id) 50 | return names 51 | 52 | 53 | def get_defined_names(base_dir: str) -> set[str]: 54 | """Find all top-level names defined in Python files that would be exported with *. 55 | 56 | Args: 57 | base_dir (str): Base directory to search within 58 | 59 | Returns: 60 | Set[str]: Set of defined names that could be exported via wildcard imports 61 | 62 | """ 63 | defined_names = set() 64 | 65 | for root, _, files in os.walk(base_dir): 66 | for file in files: 67 | if not file.endswith(".py"): 68 | continue 69 | 70 | filepath = Path(root) / file 71 | 72 | try: 73 | with filepath.open(encoding="utf-8") as f: 74 | file_content = f.read() 75 | 76 | tree = ast.parse(file_content, filepath) 77 | 78 | # Look for top-level definitions that don't start with underscore 79 | for node in ast.iter_child_nodes(tree): 80 | defined_names.update(_extract_names_from_node(node)) 81 | except (SyntaxError, UnicodeDecodeError, IsADirectoryError) as e: 82 | print(f"Error parsing {filepath}: {e}", file=sys.stderr) 83 | 84 | return defined_names 85 | 86 | 87 | def find_conflicts(base_dir: str = "albumentations") -> tuple[set[str], set[str], set[str]]: 88 | """Find conflicts between module names and defined names. 89 | 90 | Args: 91 | base_dir (str): Base directory to check 92 | 93 | Returns: 94 | Tuple[Set[str], Set[str], Set[str]]: Tuple containing (module_names, defined_names, conflicts) 95 | 96 | """ 97 | module_names = get_module_names(base_dir) 98 | defined_names = get_defined_names(base_dir) 99 | 100 | conflicts = module_names.intersection(defined_names) 101 | 102 | return module_names, defined_names, conflicts 103 | 104 | 105 | def main(): 106 | """Main entry point for the script.""" 107 | base_dir = "albumentations" 108 | 109 | # Check if base directory exists 110 | if not Path(base_dir).is_dir(): 111 | print(f"Error: Directory '{base_dir}' not found.", file=sys.stderr) 112 | sys.exit(1) 113 | 114 | _, _, conflicts = find_conflicts(base_dir) 115 | 116 | if conflicts: 117 | print("⚠️ Naming conflicts detected between modules and defined names:", file=sys.stderr) 118 | for conflict in sorted(conflicts): 119 | print(f" - '{conflict}' is both a module name and a function/class", file=sys.stderr) 120 | print("\nThese conflicts can cause problems with tools like Hydra that use direct module paths.") 121 | print("Consider renaming either the module or the function/class.") 122 | sys.exit(1) 123 | 124 | print("✅ No naming conflicts detected between modules and defined names.") 125 | sys.exit(0) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /tools/check_no_defaults_in_schemas.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Pre-commit hook to check that classes inheriting from BaseModel (like InitSchema) 3 | do not have default values in their field definitions. 4 | 5 | This enforces the albumentations coding guideline: 6 | "We do not have ANY default values in the InitSchema class" 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | import argparse 12 | import ast 13 | import sys 14 | from pathlib import Path 15 | 16 | 17 | class DefaultValueChecker(ast.NodeVisitor): 18 | def __init__(self): 19 | self.errors: list[tuple[str, int, str]] = [] 20 | self.current_file = "" 21 | self.basemodel_classes: set[str] = set() 22 | self.class_inheritance: dict[str, list[str]] = {} 23 | 24 | def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: N802 25 | """Visit a class definition node to check for BaseModel inheritance.""" 26 | # Track class inheritance 27 | base_names = [] 28 | for base in node.bases: 29 | if isinstance(base, ast.Name): 30 | base_names.append(base.id) 31 | elif isinstance(base, ast.Attribute): 32 | # Handle cases like pydantic.BaseModel 33 | base_names.append(ast.unparse(base)) 34 | 35 | self.class_inheritance[node.name] = base_names 36 | 37 | # Check if this class inherits from BaseModel (directly or indirectly) 38 | if self._inherits_from_basemodel(node.name): 39 | self.basemodel_classes.add(node.name) 40 | self._check_class_fields(node) 41 | 42 | self.generic_visit(node) 43 | 44 | def _inherits_from_basemodel(self, class_name: str) -> bool: 45 | """Check if a class inherits from BaseModel directly or indirectly.""" 46 | if class_name not in self.class_inheritance: 47 | return False 48 | 49 | bases = self.class_inheritance[class_name] 50 | 51 | # Direct inheritance 52 | for base in bases: 53 | if base in ("BaseModel", "pydantic.BaseModel", "BaseTransformInitSchema"): 54 | return True 55 | 56 | # Indirect inheritance (recursive check) 57 | return any(base in self.class_inheritance and self._inherits_from_basemodel(base) for base in bases) 58 | 59 | def _check_class_fields(self, node: ast.ClassDef) -> None: # noqa: C901, PLR0912 60 | """Check for default values in class field annotations.""" 61 | for item in node.body: 62 | if isinstance(item, ast.AnnAssign) and item.value is not None: 63 | # This is an annotated assignment with a default value 64 | field_name = ast.unparse(item.target) if hasattr(ast, "unparse") else str(item.target) 65 | 66 | # Skip special fields that are allowed to have defaults 67 | if self._is_allowed_default_field(field_name): 68 | continue 69 | 70 | # Skip discriminator fields (Literal types used for Pydantic discriminated unions) 71 | if self._is_discriminator_field(item): 72 | continue 73 | 74 | # Check if it's a Field() call with default 75 | if isinstance(item.value, ast.Call): 76 | if isinstance(item.value.func, ast.Name) and item.value.func.id == "Field": 77 | # Check if Field() has a default parameter or positional arg 78 | has_default = False 79 | 80 | # Check positional arguments (first arg is default if present) 81 | if item.value.args: 82 | has_default = True 83 | 84 | # Check keyword arguments for 'default' 85 | for keyword in item.value.keywords: 86 | if keyword.arg == "default": 87 | has_default = True 88 | break 89 | 90 | if has_default: 91 | self.errors.append( 92 | ( 93 | self.current_file, 94 | item.lineno, 95 | f"Field '{field_name}' in BaseModel class '{node.name}' has a default value", 96 | ), 97 | ) 98 | else: 99 | # Direct assignment (not Field()) 100 | self.errors.append( 101 | ( 102 | self.current_file, 103 | item.lineno, 104 | f"Field '{field_name}' in BaseModel class '{node.name}' has a default value", 105 | ), 106 | ) 107 | 108 | elif isinstance(item, ast.Assign): 109 | # Handle regular assignments (var = value) 110 | for target in item.targets: 111 | if isinstance(target, ast.Name): 112 | field_name = target.id 113 | if not self._is_allowed_default_field(field_name): 114 | self.errors.append( 115 | ( 116 | self.current_file, 117 | item.lineno, 118 | f"Field '{field_name}' in BaseModel class '{node.name}' has a default value", 119 | ), 120 | ) 121 | 122 | def _is_allowed_default_field(self, field_name: str) -> bool: 123 | """Check if a field is allowed to have default values.""" 124 | # Allow private fields, class variables, and special methods 125 | if field_name.startswith("_"): 126 | return True 127 | 128 | # Allow specific field names that might legitimately have defaults 129 | allowed_fields = { 130 | "model_config", # Pydantic config 131 | "strict", # Core validation system field 132 | "__annotations__", 133 | "__module__", 134 | "__qualname__", 135 | } 136 | 137 | return field_name in allowed_fields 138 | 139 | def _is_discriminator_field(self, item: ast.AnnAssign) -> bool: 140 | """Check if this is a discriminator field for Pydantic discriminated unions.""" 141 | if not item.annotation: 142 | return False 143 | 144 | # Check if the annotation is a Literal type 145 | annotation_str = ast.unparse(item.annotation) if hasattr(ast, "unparse") else str(item.annotation) 146 | 147 | # Look for Literal["some_value"] pattern 148 | if "Literal[" in annotation_str and isinstance(item.value, ast.Constant) and isinstance(item.value.value, str): 149 | literal_value = item.value.value 150 | # Check if the literal value appears in the annotation 151 | if f'"{literal_value}"' in annotation_str or f"'{literal_value}'" in annotation_str: 152 | return True 153 | 154 | return False 155 | 156 | 157 | def check_file(file_path: Path) -> list[tuple[str, int, str]]: 158 | """Check a single Python file for default values in BaseModel classes.""" 159 | try: 160 | with file_path.open(encoding="utf-8") as f: 161 | content = f.read() 162 | 163 | tree = ast.parse(content, filename=str(file_path)) 164 | checker = DefaultValueChecker() 165 | checker.current_file = str(file_path) 166 | checker.visit(tree) 167 | 168 | except SyntaxError as e: 169 | print(f"Syntax error in {file_path}: {e}") 170 | return [] 171 | except (OSError, UnicodeDecodeError) as e: 172 | print(f"Error processing {file_path}: {e}") 173 | return [] 174 | else: 175 | return checker.errors 176 | 177 | 178 | def main() -> int: 179 | parser = argparse.ArgumentParser( 180 | description="Check that BaseModel classes don't have default values", 181 | ) 182 | parser.add_argument( 183 | "files", 184 | nargs="*", 185 | help="Python files to check", 186 | ) 187 | parser.add_argument( 188 | "--exclude-pattern", 189 | action="append", 190 | default=[], 191 | help="Exclude files matching this pattern", 192 | ) 193 | 194 | args = parser.parse_args() 195 | 196 | if not args.files: 197 | return 0 198 | 199 | all_errors = [] 200 | 201 | for file_path in args.files: 202 | path = Path(file_path) 203 | 204 | # Skip non-Python files 205 | if path.suffix != ".py": 206 | continue 207 | 208 | # Skip excluded patterns 209 | skip = False 210 | for pattern in args.exclude_pattern: 211 | if pattern in str(path): 212 | skip = True 213 | break 214 | if skip: 215 | continue 216 | 217 | errors = check_file(path) 218 | all_errors.extend(errors) 219 | 220 | # Report errors 221 | if all_errors: 222 | print("❌ Found default values in BaseModel classes:") 223 | for file_path, line_no, message in all_errors: 224 | print(f" {file_path}:{line_no}: {message}") 225 | return 1 226 | return 0 227 | 228 | 229 | if __name__ == "__main__": 230 | sys.exit(main()) 231 | --------------------------------------------------------------------------------