├── .codecov.yml ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── actions │ ├── cache │ │ └── action.yml │ ├── pip-list │ │ └── action.yml │ ├── pkg-create │ │ └── action.yml │ ├── pkg-install │ │ └── action.yml │ └── unittesting │ │ └── action.yml ├── dependabot.yml ├── labeler-config.yml ├── markdown.links.config.json ├── mergify.yml ├── scripts │ ├── find-unused-caches.py │ └── find-unused-caches.txt ├── stale.yml └── workflows │ ├── check-docs.yml │ ├── check-md-links.yml │ ├── check-package.yml │ ├── check-precommit.yml │ ├── check-schema.yml │ ├── check-typing.yml │ ├── ci-cli.yml │ ├── ci-rtfd.yml │ ├── ci-scripts.yml │ ├── ci-testing.yml │ ├── ci-use-checks.yaml │ ├── cleanup-caches.yml │ ├── cron-clear-cache.yml │ ├── deploy-docs.yml │ ├── label-pr.yml │ └── release-pypi.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── .build_docs.sh ├── .readthedocs.yaml ├── Makefile ├── make.bat └── source │ ├── _static │ ├── copybutton.js │ └── images │ │ ├── icon.svg │ │ ├── logo-large.svg │ │ ├── logo-small.svg │ │ ├── logo.png │ │ └── logo.svg │ ├── _templates │ └── theme_variables.jinja │ ├── conf.py │ ├── index.rst │ └── test-page.rst ├── pyproject.toml ├── requirements ├── _docs.txt ├── _tests.txt ├── cli.txt ├── core.txt ├── docs.txt ├── gha-package.txt ├── gha-schema.txt └── typing.txt ├── scripts ├── adjust-torch-versions.py ├── inject-selector-script.py └── run_standalone_tests.sh ├── setup.py ├── src └── lightning_utilities │ ├── __about__.py │ ├── __init__.py │ ├── cli │ ├── __init__.py │ ├── __main__.py │ └── dependencies.py │ ├── core │ ├── __init__.py │ ├── apply_func.py │ ├── enums.py │ ├── imports.py │ ├── inheritance.py │ ├── overrides.py │ └── rank_zero.py │ ├── docs │ ├── __init__.py │ ├── formatting.py │ └── retriever.py │ ├── install │ ├── __init__.py │ └── 
requirements.py │ ├── py.typed │ └── test │ ├── __init__.py │ └── warning.py └── tests ├── scripts ├── __init__.py └── test_adjust_torch_versions.py └── unittests ├── __init__.py ├── cli └── test_dependencies.py ├── conftest.py ├── core ├── test_apply_func.py ├── test_enums.py ├── test_imports.py ├── test_inheritance.py ├── test_overrides.py └── test_rank_zero.py ├── docs ├── __init__.py ├── test_formatting.py └── test_retriever.py ├── mocks.py └── test └── test_warnings.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | # https://docs.codecov.io/docs/codecovyml-reference 6 | codecov: 7 | bot: "codecov-io" 8 | strict_yaml_branch: "yaml-config" 9 | require_ci_to_pass: yes 10 | notify: 11 | # after_n_builds: 2 12 | wait_for_ci: yes 13 | 14 | coverage: 15 | precision: 0 # 2 = xx.xx%, 0 = xx% 16 | round: nearest # how coverage is rounded: down/up/nearest 17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 18 | status: 19 | # https://codecov.readme.io/v1.0/docs/commit-status 20 | project: 21 | default: 22 | informational: true 23 | target: 99% # specify the target coverage for each commit status 24 | threshold: 30% # allow this little decrease on project 25 | # https://github.com/codecov/support/wiki/Filtering-Branches 26 | # branches: main 27 | if_ci_failed: error 28 | # https://github.com/codecov/support/wiki/Patch-Status 29 | patch: 30 | default: 31 | informational: true 32 | target: 50% # specify the target "X%" coverage to hit 33 | # threshold: 50% # allow this much decrease on patch 34 | changes: false 35 | 36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations 37 | github_checks: 38 | annotations: false 39 | 40 | parsers: 41 | gcov: 42 | branch_detection: 43 | conditional: true 44 | loop: true 45 | 
macro: false 46 | method: false 47 | javascript: 48 | enable_partials: false 49 | 50 | comment: 51 | layout: header, diff 52 | require_changes: false 53 | behavior: default # update if exists else create new 54 | # branches: * 55 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This is a comment. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # These owners will be the default owners for everything in 5 | # the repo. Unless a later match takes precedence, 6 | # @global-owner1 and @global-owner2 will be requested for 7 | # review when someone opens a pull request. 8 | * @borda @ethanwharris @justusschock @tchaton 9 | 10 | # CI/CD and configs 11 | /.github/ @borda @ethanwharris @justusschock @tchaton 12 | *.yml @borda @ethanwharris @justusschock @tchaton 13 | 14 | # Docs 15 | /docs/ @borda 16 | /.github/*.md @borda 17 | /.github/ISSUE_TEMPLATE/ @borda 18 | 19 | /.github/CODEOWNERS @borda 20 | /setup.py @borda 21 | 22 | /src @borda @ethanwharris @justusschock @tchaton 23 | /tests/unittests @borda @ethanwharris @justusschock @tchaton 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ### To Reproduce 14 | 15 | Steps to reproduce the behavior... 16 | 17 | 18 | 19 |
20 | Code sample 21 | 22 | 24 | 25 |
26 | 27 | ### Expected behavior 28 | 29 | 30 | 31 | ### Additional context 32 | 33 |
34 | Environment details 35 | 36 | - PyTorch Version (e.g., 1.0): 37 | - OS (e.g., Linux): 38 | - How you installed PyTorch (`conda`, `pip`, source): 39 | - Build command you used (if compiling from source): 40 | - Python version: 41 | - CUDA/cuDNN version: 42 | - GPU models and configuration: 43 | - Any other relevant information: 44 | 45 |
46 | 47 | 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and doc fixes 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## 📚 Documentation 10 | 11 | For typos or docs fixes, please go ahead and submit a PR (no need to open an issue). 12 | If you are not sure about the proper solution, please describe the issue here... 13 | 14 | Thanks! 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Alternatives 18 | 19 | 20 | 21 | ### Additional context 22 | 23 | 24 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 |
2 | Before submitting 3 | 4 | - [ ] Was this discussed/agreed via a GitHub issue? (not needed for typo fixes and docs improvements) 5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section? 6 | - [ ] Did you make sure to update the docs? 7 | - [ ] Did all existing and newly added tests pass locally? 8 | 9 |
10 | 11 | ## What does this PR do? 12 | 13 | Fixes # (issue). 14 | 15 | ## PR review 16 | 17 | Anyone in the community is free to review the PR once the tests have passed. 18 | If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged. 19 | 20 | 25 | -------------------------------------------------------------------------------- /.github/actions/cache/action.yml: -------------------------------------------------------------------------------- 1 | name: Complex caching 2 | description: some more complex caching - pip & conda 3 | 4 | inputs: 5 | python-version: 6 | description: Python version 7 | required: true 8 | requires: 9 | description: define oldest, latest, or an empty string 10 | required: false 11 | default: "" 12 | offset: 13 | description: some extra hash for pip cache 14 | required: false 15 | default: "" 16 | interval: 17 | description: cache hash reset interval in days 18 | required: false 19 | default: "7" 20 | 21 | runs: 22 | using: "composite" 23 | steps: 24 | - name: Determine caches 25 | id: cache_dirs 26 | run: echo "pip_dir=$(pip cache dir)" >> $GITHUB_OUTPUT 27 | shell: bash 28 | 29 | - name: Cache 💽 pip 30 | uses: actions/cache@v3 31 | with: 32 | path: ${{ steps.cache_dirs.outputs.pip_dir }} 33 | key: py${{ inputs.python-version }}-pip-${{ inputs.offset }}-${{ hashFiles('requirements.txt') }} 34 | restore-keys: py${{ inputs.python-version }}-pip-${{ inputs.offset }}- 35 | enableCrossOsArchive: true 36 | 37 | - name: Cache 💽 conda 38 | uses: actions/cache@v3 39 | if: runner.os == 'Linux' 40 | with: 41 | path: ~/conda_pkgs_dir 42 | key: py${{ inputs.python-version }}-conda-${{ inputs.offset }} 43 | restore-keys: py${{ inputs.python-version }}-conda-${{ inputs.offset }} 44 | -------------------------------------------------------------------------------- /.github/actions/pip-list/action.yml: -------------------------------------------------------------------------------- 1 | name: Print pip list 2 | description: 
Print pip list for quick access to environment information 3 | 4 | inputs: 5 | unfold: 6 | description: Whether to unfold the output of pip list 7 | required: false 8 | default: "false" 9 | 10 | runs: 11 | using: "composite" 12 | steps: 13 | - name: pip list 14 | run: | 15 | if [ "${{ inputs.unfold }}" = "true" ]; then 16 | echo '
' >> $GITHUB_STEP_SUMMARY 17 | else 18 | echo '
' >> $GITHUB_STEP_SUMMARY 19 | fi 20 | echo 'pip list' >> $GITHUB_STEP_SUMMARY 21 | echo '' >> $GITHUB_STEP_SUMMARY 22 | echo '```' >> $GITHUB_STEP_SUMMARY 23 | pip list >> $GITHUB_STEP_SUMMARY 24 | echo '```' >> $GITHUB_STEP_SUMMARY 25 | echo '' >> $GITHUB_STEP_SUMMARY 26 | echo '
' >> $GITHUB_STEP_SUMMARY 27 | pip list # also print to stdout 28 | shell: bash 29 | -------------------------------------------------------------------------------- /.github/actions/pkg-create/action.yml: -------------------------------------------------------------------------------- 1 | name: Create and check package 2 | description: building, checking the package 3 | 4 | runs: 5 | using: "composite" 6 | steps: 7 | - name: Create package 📦 8 | # python setup.py clean 9 | run: python -m build --verbose 10 | shell: bash 11 | 12 | - name: Check package 📦 13 | working-directory: dist 14 | run: | 15 | ls -lh . 16 | twine check * --strict 17 | shell: bash 18 | -------------------------------------------------------------------------------- /.github/actions/pkg-install/action.yml: -------------------------------------------------------------------------------- 1 | name: Install and check package 2 | description: installing and validation the package 3 | 4 | inputs: 5 | install-from: 6 | description: "Define if the package is from archive or wheel" 7 | required: true 8 | pkg-folder: 9 | description: "Unique name for collecting artifacts" 10 | required: false 11 | default: "pypi" 12 | pkg-extras: 13 | description: "optional extras which are needed to include also []" 14 | required: false 15 | default: "" 16 | pip-flags: 17 | description: "additional pip install flags" 18 | required: false 19 | default: "-f https://download.pytorch.org/whl/cpu/torch_stable.html" 20 | import-name: 21 | description: "Import name to test with after installation" 22 | required: true 23 | custom-import-code: 24 | description: "additional import statement, need to be full python code" 25 | required: false 26 | default: "" 27 | 28 | runs: 29 | using: "composite" 30 | steps: 31 | - name: show packages 32 | working-directory: ${{ inputs.pkg-folder }} 33 | run: | 34 | ls -lh 35 | pip -V 36 | echo "PKG_WHEEL=$(ls *.whl | head -n1)" >> $GITHUB_ENV 37 | echo "PKG_SOURCE=$(ls *.tar.gz | head -n1)" >> 
$GITHUB_ENV 38 | pip list 39 | shell: bash 40 | 41 | - name: Install package (archive) 42 | if: ${{ inputs.install-from == 'archive' }} 43 | working-directory: ${{ inputs.pkg-folder }} # same folder the listing step resolved PKG_SOURCE in 44 | run: | 45 | set -ex 46 | pip install '${{ env.PKG_SOURCE }}${{ inputs.pkg-extras }}' \ 47 | --force-reinstall ${{ inputs.pip-flags }} 48 | pip list 49 | shell: bash 50 | 51 | - name: Install package (wheel) 52 | if: ${{ inputs.install-from == 'wheel' }} 53 | working-directory: ${{ inputs.pkg-folder }} 54 | run: | 55 | set -ex 56 | pip install '${{ env.PKG_WHEEL }}${{ inputs.pkg-extras }}' \ 57 | --force-reinstall ${{ inputs.pip-flags }} 58 | pip list 59 | shell: bash 60 | 61 | - name: package check / import 62 | if: ${{ inputs.import-name != '' }} 63 | run: | 64 | python -c "import ${{ inputs.import-name }} as pkg; print(f'version: {pkg.__version__}')" 65 | shell: bash 66 | 67 | - name: package check / custom import 68 | if: ${{ inputs.custom-import-code != '' }} 69 | run: | 70 | python -c '${{ inputs.custom-import-code }}' 71 | shell: bash 72 | 73 | - name: Uninstall all 74 | # TODO: reset env / consider add as conda 75 | run: | 76 | pip freeze > _reqs.txt 77 | pip uninstall -y -r _reqs.txt 78 | shell: bash 79 | -------------------------------------------------------------------------------- /.github/actions/unittesting/action.yml: -------------------------------------------------------------------------------- 1 | name: Unittest and coverage 2 | description: pull data samples -> unittests 3 | 4 | inputs: 5 | python-version: 6 | description: Python version 7 | required: true 8 | pkg-name: 9 | description: package name for coverage collections 10 | required: true 11 | requires: 12 | description: define oldest or latest 13 | required: false 14 | default: "" 15 | dirs: 16 | description: Testing folders per domains, space separated string 17 | required: false 18 | default: "."
19 | pytest-args: 20 | description: Additional pytest arguments such as `--timeout=120` 21 | required: false 22 | default: "" 23 | shell-type: 24 | description: Define Shell type 25 | required: false 26 | default: "bash" 27 | 28 | runs: 29 | using: "composite" 30 | steps: 31 | - name: Python 🐍 details 32 | run: | 33 | python --version 34 | pip --version 35 | pip list 36 | shell: ${{ inputs.shell-type }} 37 | 38 | - name: Determine artifact file name 39 | run: echo "artifact=test-results-${{ runner.os }}-py${{ inputs.python-version }}-${{ inputs.requires }}" >> $GITHUB_OUTPUT 40 | id: location 41 | shell: bash 42 | 43 | - name: Unittests 44 | working-directory: ./tests 45 | run: | 46 | python -m pytest ${{ inputs.dirs }} \ 47 | --cov=${{ inputs.pkg-name }} --durations=50 ${{ inputs.pytest-args }} \ 48 | --junitxml="${{ steps.location.outputs.artifact }}.xml" 49 | shell: ${{ inputs.shell-type }} 50 | 51 | - name: Upload pytest results 52 | uses: actions/upload-artifact@v4 53 | with: 54 | name: ${{ steps.location.outputs.artifact }} 55 | path: "tests/${{ steps.location.outputs.artifact }}.xml" # pytest runs in ./tests, so the JUnit XML is written there 56 | include-hidden-files: true 57 | if: failure() 58 | 59 | - name: Statistics 60 | if: success() 61 | working-directory: ./tests 62 | run: | 63 | coverage xml 64 | coverage report 65 | shell: ${{ inputs.shell-type }} 66 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 2 | version: 2 3 | updates: 4 | # Enable version updates for python 5 | - package-ecosystem: "pip" 6 | directory: "/requirements" 7 | schedule: 8 | interval: "weekly" 9 | labels: ["requirements", "enhancement"] 10 | pull-request-branch-name: 11 | 
separator: "-" 12 | open-pull-requests-limit: 5 13 | 14 | # Enable version updates for GitHub
Actions 15 | - package-ecosystem: "github-actions" 16 | directory: "/" 17 | schedule: 18 | interval: "weekly" 19 | labels: ["ci/cd", "enhancement"] 20 | pull-request-branch-name: 21 | separator: "-" 22 | open-pull-requests-limit: 5 23 | -------------------------------------------------------------------------------- /.github/labeler-config.yml: -------------------------------------------------------------------------------- 1 | documentation: 2 | - changed-files: 3 | - any-glob-to-any-file: 4 | - docs/**/* 5 | - README.md 6 | - requirements/_docs.txt 7 | 8 | ci/cd: 9 | - changed-files: 10 | - any-glob-to-any-file: 11 | - .github/actions/**/* 12 | - .github/scripts/* 13 | - .github/workflows/* 14 | - .pre-commit-config.yaml 15 | 16 | package: 17 | - changed-files: 18 | - any-glob-to-any-file: 19 | - src/**/* 20 | - MANIFEST.in 21 | - pyproject.toml 22 | - setup.py 23 | 24 | tests: 25 | - changed-files: 26 | - any-glob-to-any-file: 27 | - tests/**/* 28 | - requirements/_tests.txt 29 | 30 | dependencies: 31 | - changed-files: 32 | - any-glob-to-any-file: 33 | - requirements/* 34 | 35 | release: 36 | - base-branch: "release/*" 37 | -------------------------------------------------------------------------------- /.github/markdown.links.config.json: -------------------------------------------------------------------------------- 1 | { 2 | "ignorePatterns": [ 3 | { 4 | "pattern": "^https://github.com/Lightning-AI/lightning/pull/.*" 5 | } 6 | ], 7 | "httpHeaders": [ 8 | { 9 | "urls": [ 10 | "https://github.com/", 11 | "https://guides.github.com/", 12 | "https://help.github.com/", 13 | "https://docs.github.com/" 14 | ], 15 | "headers": { 16 | "Accept-Encoding": "zstd, br, gzip, deflate" 17 | } 18 | } 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /.github/mergify.yml: -------------------------------------------------------------------------------- 1 | pull_request_rules: 2 | - name: warn on conflicts 3 | conditions: 4 | - 
conflict 5 | - -draft # filter-out GH draft PRs 6 | - -label="has conflicts" 7 | actions: 8 | # comment: 9 | # message: This pull request is now in conflict... :( 10 | label: 11 | add: ["has conflicts"] 12 | 13 | - name: resolved conflicts 14 | conditions: 15 | - -conflict 16 | - label="has conflicts" 17 | - -draft # filter-out GH draft PRs 18 | - -merged # not merged yet 19 | - -closed 20 | actions: 21 | label: 22 | remove: ["has conflicts"] 23 | 24 | - name: add core reviewer 25 | conditions: 26 | # number of review approvals 27 | - "#approved-reviews-by<2" 28 | actions: 29 | request_reviews: 30 | users: 31 | - Borda 32 | -------------------------------------------------------------------------------- /.github/scripts/find-unused-caches.py: -------------------------------------------------------------------------------- 1 | """Script for listing repository caches and selecting those unused for a given number of days.""" 2 | 3 | import os 4 | from datetime import timedelta 5 | 6 | 7 | def fetch_all_caches(repository: str, token: str, per_page: int = 100, max_pages: int = 100) -> list[dict]: 8 | """Fetch and return the list of all Actions caches of a given repository. 9 | 10 | Args: 11 | repository: repository name in form `owner/repo` 12 | token: authentication token for GitHub API calls 13 | per_page: number of items per listing page 14 | max_pages: max number of listing pages to fetch 15 | 16 | """ 17 | import requests 18 | from pandas import Timestamp, to_datetime 19 | 20 | # Accumulates the cache records from all fetched pages 21 | all_caches = [] 22 | 23 | for page in range(max_pages):  # NOTE(review): stops after `max_pages` pages even if more caches remain 24 | # Get a page of caches for the repository 25 | url = f"https://api.github.com/repos/{repository}/actions/caches?page={page + 1}&per_page={per_page}" 26 | headers = {"Authorization": f"token {token}"} 27 | response = requests.get(url, headers=headers, timeout=10).json() 28 | if "total_count" not in response: 29 | raise 
RuntimeError(response.get("message")) 30 | print(f"fetching page...
{page} with {per_page} items of expected {response.get('total_count')}")  # `page` is 0-based here; the request itself uses `page + 1` 31 | caches = response.get("actions_caches", []) 32 | 33 | # Append the caches from this page to the overall list 34 | all_caches.extend(caches) 35 | 36 | # A page shorter than `per_page` (including an empty one) is the last page 37 | if len(caches) < per_page: 38 | break 39 | 40 | # Print one summary line per cache and attach `last_used_days` (time since last access), which `main` filters on 41 | if all_caches: 42 | current_date = Timestamp.now(tz="UTC") 43 | print(f"Caches {len(all_caches)} for {repository}:") 44 | for cache in all_caches: 45 | cache_key = cache["id"]  # NOTE: this is the numeric cache ID, printed under the "Cache Key" label 46 | created_at = to_datetime(cache["created_at"]) 47 | last_used_at = to_datetime(cache["last_accessed_at"]) 48 | cache["last_used_days"] = current_date - last_used_at 49 | age_used = cache["last_used_days"].round(freq="min") 50 | size = cache["size_in_bytes"] / (1024 * 1024) 51 | print( 52 | f"- Cache Key: {cache_key} |" 53 | f" Created At: {created_at.strftime('%Y-%m-%d %H:%M')} |" 54 | f" Used At: {last_used_at.strftime('%Y-%m-%d %H:%M')} [{age_used}] |" 55 | f" Size: {size:.2f} MB" 56 | ) 57 | else: 58 | print("No caches found for the repository.") 59 | return all_caches 60 | 61 | 62 | def main(repository: str, token: str, age_days: float = 7, output_file: str = "unused-cashes.txt") -> None:  # NOTE(review): default file name is spelled "cashes"; check whatever consumes it before renaming 63 | """Entry point for CLI.
64 | 65 | Args: 66 | repository: GitHub repository name in form `owner/repo` 67 | token: authentication token for making API calls 68 | age_days: select caches whose last access is more than this many days ago 69 | output_file: path to a file where the IDs of the selected caches are written, one per line 70 | 71 | """ 72 | caches = fetch_all_caches(repository, token) 73 | 74 | delta_days = timedelta(days=age_days) 75 | old_caches = [str(cache["id"]) for cache in caches if cache["last_used_days"] > delta_days]  # `last_used_days` is attached to each record by fetch_all_caches 76 | print(f"found {len(old_caches)} old caches:\n {old_caches}") 77 | 78 | with open(output_file, "w", encoding="utf8") as fw: 79 | fw.write(os.linesep.join(old_caches)) 80 | 81 | 82 | if __name__ == "__main__": 83 | from jsonargparse import auto_cli, set_parsing_settings 84 | 85 | set_parsing_settings(parse_optionals_as_positionals=True) 86 | auto_cli(main) 87 | -------------------------------------------------------------------------------- /.github/scripts/find-unused-caches.txt: -------------------------------------------------------------------------------- 1 | # Requirements for running equally named script, 2 | # having it in extra file to prevent version discrepancy if hardcoded in workflow 3 | 4 | jsonargparse[signatures] >=4.38.0 5 | requests 6 | pandas 7 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 60 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 14 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - pinned 10 | - security 11 | # Label to use when marking an issue as stale 12 | staleLabel: won't fix 13 | # Comment to post when marking an issue as stale.
Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue. Set to `false` to disable 19 | closeComment: false 20 | 21 | # Set to true to ignore issues in a project (defaults to false) 22 | exemptProjects: true 23 | # Set to true to ignore issues in a milestone (defaults to false) 24 | exemptMilestones: true 25 | # Set to true to ignore issues with an assignee (defaults to false) 26 | exemptAssignees: true 27 | -------------------------------------------------------------------------------- /.github/workflows/check-docs.yml: -------------------------------------------------------------------------------- 1 | name: Building docs 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | actions-ref: 7 | description: "Version of actions, normally the same as workflow" 8 | required: false 9 | type: string 10 | default: "" 11 | python-version: 12 | description: "Python version to use" 13 | default: "3.9" 14 | required: false 15 | type: string 16 | docs-dir: 17 | description: "Working directory to run make docs in" 18 | default: "./docs" 19 | required: false 20 | type: string 21 | timeout-minutes: 22 | description: "timeout-minutes for each job" 23 | default: 15 24 | required: false 25 | type: number 26 | requirements-file: 27 | description: "path to the requirement file" 28 | default: "requirements/docs.txt" 29 | required: false 30 | type: string 31 | env-vars: 32 | description: "custom environment variables in json format" 33 | required: false 34 | type: string 35 | default: | 36 | { 37 | "SPHINX_MOCK_REQUIREMENTS": 0, 38 | } 39 | make-target: 40 | description: "what test configs to run in json format" 41 | required: false 42 | type: string 43 | default: | 44 | ["html", "doctest", "linkcheck"] 45 | install-tex: 46 | description: "optional installing Texlive 
support - true|false" 47 | required: false 48 | type: string 49 | default: false 50 | 51 | defaults: 52 | run: 53 | shell: bash 54 | 55 | env: 56 | # just use CPU version since running on CPU machine 57 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 58 | # default 0 means to keep for the maximum time 59 | KEEP_DAYS: 0 60 | 61 | jobs: 62 | make-docs: 63 | runs-on: ubuntu-22.04 64 | env: ${{ fromJSON(inputs.env-vars) }} 65 | strategy: 66 | fail-fast: false 67 | matrix: 68 | target: ${{ fromJSON(inputs.make-target) }} 69 | steps: 70 | - name: Checkout 🛎 71 | uses: actions/checkout@v4 72 | with: 73 | submodules: recursive 74 | 75 | - name: Set up Python 🐍 ${{ inputs.python-version }} 76 | uses: actions/setup-python@v5 77 | with: 78 | python-version: ${{ inputs.python-version }} 79 | cache: "pip" 80 | 81 | - name: Install pandoc & texlive 82 | if: ${{ inputs.install-tex == 'true' }} 83 | timeout-minutes: 10 84 | run: | 85 | sudo apt-get update --fix-missing 86 | sudo apt-get install -y \ 87 | pandoc \ 88 | texlive-latex-extra \ 89 | dvipng \ 90 | texlive-pictures \ 91 | latexmk 92 | - name: Install dependencies 93 | timeout-minutes: 20 94 | run: | 95 | pip --version 96 | pip install -e . 
-U -r ${{ inputs.requirements-file }} -f ${TORCH_URL} 97 | pip list 98 | 99 | - name: Pull reusable 🤖 actions️ 100 | if: ${{ inputs.actions-ref != '' }} 101 | uses: actions/checkout@v4 102 | with: 103 | ref: ${{ inputs.actions-ref }} 104 | path: .cicd 105 | repository: Lightning-AI/utilities 106 | - name: Print 🖨️ dependencies 107 | if: ${{ inputs.actions-ref != '' }} 108 | uses: ./.cicd/.github/actions/pip-list 109 | 110 | - name: Build documentation 111 | working-directory: ${{ inputs.docs-dir }} 112 | run: | 113 | make ${{ matrix.target }} \ 114 | --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" 115 | 116 | - name: Shorten keep artifact 117 | if: startsWith(github.event_name, 'pull_request') 118 | run: echo "KEEP_DAYS=7" >> $GITHUB_ENV 119 | - name: Upload built docs 120 | uses: actions/upload-artifact@v4 121 | with: 122 | name: docs-${{ matrix.target }}-${{ github.sha }} 123 | path: ${{ inputs.docs-dir }}/build/ 124 | retention-days: ${{ env.KEEP_DAYS }} 125 | include-hidden-files: true 126 | -------------------------------------------------------------------------------- /.github/workflows/check-md-links.yml: -------------------------------------------------------------------------------- 1 | name: Check Markdown links 2 | # https://github.com/gaurav-nelson/github-action-markdown-link-check 3 | 4 | on: 5 | workflow_call: 6 | inputs: 7 | base-branch: 8 | description: "Default branch name" 9 | required: true 10 | type: string 11 | default: "master" 12 | config-file: 13 | description: "Config file (JSON) used for markdown link checks" 14 | required: false 15 | type: string 16 | default: "" 17 | 18 | jobs: 19 | markdown-link-check: 20 | runs-on: ubuntu-24.04 21 | env: 22 | CONFIG_FILE: ${{ inputs.config-file }} 23 | MODIFIED_ONLY: "no" 24 | steps: 25 | - uses: actions/checkout@master 26 | 27 | - name: Create local version of config 28 | if: ${{ inputs.config-file == '' }} 29 | run: | 30 | echo '{ 31 | "ignorePatterns": [ 32 | { 33 | "pattern": 
"^https://github.com/${{ github.repository }}/pull/.*" 34 | } 35 | ], 36 | "httpHeaders": [ 37 | { 38 | "urls": ["https://github.com/", "https://guides.github.com/", "https://help.github.com/", "https://docs.github.com/"], 39 | "headers": { 40 | "Accept-Encoding": "zstd, br, gzip, deflate" 41 | } 42 | } 43 | ] 44 | }' > 'markdown.links.config.json' 45 | echo "CONFIG_FILE=markdown.links.config.json" >> $GITHUB_ENV 46 | cat 'markdown.links.config.json' 47 | - name: Show config 48 | run: cat ${{ env.CONFIG_FILE }} 49 | 50 | - name: narrow scope for PR 51 | if: startsWith(github.event_name, 'pull_request') 52 | run: echo "MODIFIED_ONLY=yes" >> $GITHUB_ENV 53 | - name: Checking markdown link 54 | uses: gaurav-nelson/github-action-markdown-link-check@v1 55 | with: 56 | base-branch: ${{ inputs.base-branch }} 57 | use-quiet-mode: "yes" 58 | check-modified-files-only: ${{ env.MODIFIED_ONLY }} 59 | config-file: ${{ env.CONFIG_FILE }} 60 | -------------------------------------------------------------------------------- /.github/workflows/check-package.yml: -------------------------------------------------------------------------------- 1 | name: Check package flow 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | actions-ref: 7 | description: "Version of actions, normally the same as workflow" 8 | required: true 9 | type: string 10 | artifact-name: 11 | description: "Unique name for collecting artifacts, it shall be unique for all workflows" 12 | required: true 13 | type: string 14 | install-extras: 15 | description: "optional extras which are needed to include also []" 16 | required: false 17 | type: string 18 | default: "" 19 | install-flags: 20 | description: "Additional pip install flags" 21 | required: false 22 | type: string 23 | default: "-f https://download.pytorch.org/whl/cpu/torch_stable.html" 24 | import-name: 25 | description: "Import name to test with after installation" 26 | required: true 27 | type: string 28 | custom-import-code: 29 | description: "additional 
import statement, need to be full python code" 30 | type: string 31 | required: false 32 | default: "" 33 | build-matrix: 34 | description: "what building configs in json format, expected keys are `os` and `python-version`" 35 | required: false 36 | type: string 37 | default: | 38 | { 39 | "os": ["ubuntu-latest"], 40 | } 41 | testing-matrix: 42 | description: "what test configs to run in json format, expected keys are `os` and `python-version`" 43 | required: false 44 | type: string 45 | # default operating systems should be pinned to specific versions instead of "-latest" for stability 46 | # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners#supported-runners-and-hardware-resources 47 | default: | 48 | { 49 | "os": ["ubuntu-22.04", "macos-13", "windows-2022"], 50 | "python-version": ["3.9", "3.13"] 51 | } 52 | env-vars: 53 | description: "custom environment variables in json format" 54 | required: false 55 | type: string 56 | default: | 57 | { 58 | "SAMPLE_ENV_VARIABLE": 1, 59 | } 60 | 61 | defaults: 62 | run: 63 | shell: bash 64 | 65 | jobs: 66 | pkg-build: 67 | runs-on: ${{ matrix.os }} 68 | env: ${{ fromJSON(inputs.env-vars) }} 69 | strategy: 70 | fail-fast: false 71 | matrix: ${{ fromJSON(inputs.build-matrix) }} 72 | steps: 73 | - name: Checkout 🛎️ 74 | uses: actions/checkout@v4 75 | with: 76 | fetch-depth: 0 # checkout entire history for all branches (required when using scm-based versioning) 77 | submodules: recursive 78 | - name: Set up Python 🐍 79 | uses: actions/setup-python@v5 80 | with: 81 | python-version: ${{ matrix.python-version || '3.x' }} 82 | 83 | - name: Pull reusable 🤖 actions️ 84 | uses: actions/checkout@v4 85 | with: 86 | ref: ${{ inputs.actions-ref }} 87 | path: .cicd 88 | repository: Lightning-AI/utilities 89 | - name: Prepare build env. 
90 | run: | 91 | pip install -q -r ./.cicd/requirements/gha-package.txt 92 | pip list 93 | - name: Create package 📦 94 | uses: ./.cicd/.github/actions/pkg-create 95 | - name: Upload 📤 packages 96 | if: ${{ inputs.artifact-name != '' }} 97 | uses: actions/upload-artifact@v4 98 | with: 99 | name: ${{ inputs.artifact-name }}-build-${{ strategy.job-index }} 100 | path: dist 101 | 102 | merge-artifacts: 103 | needs: pkg-build 104 | runs-on: ubuntu-latest 105 | steps: 106 | - name: Pull reusable 🤖 actions️ 107 | uses: actions/checkout@v4 108 | with: 109 | ref: ${{ inputs.actions-ref }} 110 | path: .cicd 111 | repository: Lightning-AI/utilities 112 | - name: Prepare build env. 113 | run: | 114 | pip install -q -r ./.cicd/requirements/gha-package.txt 115 | pip list 116 | 117 | - name: Download 📥 118 | uses: actions/download-artifact@v4 119 | with: 120 | # download all build artifacts 121 | pattern: ${{ inputs.artifact-name }}-build-* 122 | merge-multiple: true 123 | path: dist 124 | - name: Brief look 125 | run: | 126 | ls -lh dist/ 127 | twine check dist/* 128 | - name: Upload 📤 129 | uses: actions/upload-artifact@v4 130 | with: 131 | name: ${{ inputs.artifact-name }} 132 | path: dist 133 | 134 | pkg-check: 135 | needs: merge-artifacts 136 | runs-on: ${{ matrix.os }} 137 | env: ${{ fromJSON(inputs.env-vars) }} 138 | strategy: 139 | fail-fast: false 140 | matrix: ${{ fromJSON(inputs.testing-matrix) }} 141 | steps: 142 | - name: Checkout 🛎️ 143 | uses: actions/checkout@v4 144 | with: 145 | submodules: recursive 146 | - name: Set up Python 🐍 ${{ matrix.python-version }} 147 | uses: actions/setup-python@v5 148 | with: 149 | python-version: ${{ matrix.python-version || '3.x' }} 150 | 151 | - name: Pull reusable 🤖 actions️ 152 | uses: actions/checkout@v4 153 | with: 154 | ref: ${{ inputs.actions-ref }} 155 | path: .cicd 156 | repository: Lightning-AI/utilities 157 | - name: Download 📥 all packages 158 | if: ${{ inputs.artifact-name != '' }} 159 | uses: 
actions/download-artifact@v4 160 | with: 161 | name: ${{ inputs.artifact-name }} 162 | path: pypi 163 | - name: Installing package 📦 as Archive 164 | timeout-minutes: 10 165 | uses: ./.cicd/.github/actions/pkg-install 166 | with: 167 | install-from: "archive" 168 | pkg-extras: ${{ inputs.install-extras }} 169 | pip-flags: ${{ inputs.install-flags }} 170 | import-name: ${{ inputs.import-name }} 171 | custom-import-code: ${{ inputs.custom-import-code }} 172 | - name: Installing package 📦 as Wheel 173 | timeout-minutes: 10 174 | uses: ./.cicd/.github/actions/pkg-install 175 | with: 176 | install-from: "wheel" 177 | pkg-extras: ${{ inputs.install-extras }} 178 | pip-flags: ${{ inputs.install-flags }} 179 | import-name: ${{ inputs.import-name }} 180 | custom-import-code: ${{ inputs.custom-import-code }} 181 | 182 | # TODO: add run doctests 183 | 184 | pkg-guardian: 185 | runs-on: ubuntu-latest 186 | needs: pkg-check 187 | if: always() 188 | steps: 189 | - run: echo "${{ needs.pkg-check.result }}" 190 | - name: failing... 191 | if: needs.pkg-check.result == 'failure' 192 | run: exit 1 193 | - name: cancelled or skipped... 
194 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pkg-check.result) 195 | timeout-minutes: 1 196 | run: sleep 90 197 | -------------------------------------------------------------------------------- /.github/workflows/check-precommit.yml: -------------------------------------------------------------------------------- 1 | name: Check formatting flow 2 | 3 | on: 4 | workflow_call: 5 | secrets: 6 | github-token: 7 | description: "if provided an user GH's token, it will push update; requires `push` event" 8 | required: false 9 | inputs: 10 | python-version: 11 | description: "Python version to use" 12 | default: "3.9" 13 | required: false 14 | type: string 15 | use-cache: 16 | description: "enable using GH caching for performance boost" 17 | type: boolean 18 | required: false 19 | default: true 20 | push-fixes: 21 | description: "if provided an user GH's token, it will push update; requires `push` event" 22 | type: boolean 23 | required: false 24 | default: false 25 | 26 | defaults: 27 | run: 28 | shell: bash 29 | 30 | jobs: 31 | pre-commit: 32 | runs-on: ubuntu-22.04 33 | steps: 34 | - name: Checkout 🛎️ 35 | uses: actions/checkout@v4 36 | with: 37 | fetch-depth: 0 38 | submodules: recursive 39 | token: ${{ secrets.github-token || github.token }} 40 | 41 | - name: Set up Python 🐍 42 | uses: actions/setup-python@v5 43 | with: 44 | python-version: ${{ inputs.python-version }} 45 | 46 | - name: Cache 💽 pre-commit 47 | if: ${{ inputs.use-cache == true }} 48 | uses: actions/cache@v4 49 | with: 50 | path: ~/.cache/pre-commit 51 | key: pre-commit|py${{ inputs.python-version }}|${{ hashFiles('.pre-commit-config.yaml') }} 52 | 53 | - name: Run pre-commit 🤖 54 | id: precommit 55 | run: | 56 | pip install -q pre-commit 57 | pre-commit run --all-files 58 | 59 | - name: Fixing Pull Request ↩️ 60 | if: always() && inputs.push-fixes == true && steps.precommit.outcome == 'failure' 61 | uses: actions-js/push@v1.5 62 | with: 63 | github_token: ${{ secrets.github-token || 
github.token }} 64 | message: "pre-commit: running and fixing..." 65 | branch: ${{ github.ref_name }} 66 | -------------------------------------------------------------------------------- /.github/workflows/check-schema.yml: -------------------------------------------------------------------------------- 1 | name: Check schema flow 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | actions-ref: 7 | description: "Version of actions, normally the same as workflow" 8 | required: false 9 | type: string 10 | default: "" 11 | azure-dir: 12 | description: "Directory containing Azure Pipelines config files. Provide an empty string to skip checking on Azure Pipelines files." 13 | default: ".azure/" 14 | required: false 15 | type: string 16 | azure-schema-version: 17 | description: "Version of Azure Pipelines schema to use. Provide an empty string to skip checking on Azure Pipelines files." 18 | default: "v1.208.0" 19 | required: false 20 | type: string 21 | 22 | defaults: 23 | run: 24 | shell: bash 25 | 26 | jobs: 27 | schema: 28 | runs-on: ubuntu-24.04 29 | steps: 30 | - name: Checkout 🛎 31 | uses: actions/checkout@v4 32 | with: 33 | submodules: recursive 34 | - name: Set up Python 35 | uses: actions/setup-python@v5 36 | with: 37 | python-version: "3.10" 38 | 39 | # if actions version is given install defined versions 40 | - name: "[optional] Pull reusable 🤖 actions" 41 | if: inputs.actions-ref != '' 42 | uses: actions/checkout@v4 43 | with: 44 | ref: ${{ inputs.actions-ref }} 45 | path: .cicd 46 | repository: Lightning-AI/utilities 47 | - name: "[optional] Install recommended dependencies" 48 | if: inputs.actions-ref != '' 49 | timeout-minutes: 5 50 | run: | 51 | pip install -r ./.cicd/requirements/gha-schema.txt 52 | pip list | grep "check-jsonschema" 53 | # otherwise fall back to using the latest 54 | - name: "[default] Install recommended dependencies" 55 | if: inputs.actions-ref == '' 56 | timeout-minutes: 5 57 | run: | 58 | pip install -q check-jsonschema 59 | pip list | 
grep "check-jsonschema" 60 | 61 | - name: Scan repo 62 | id: folders 63 | run: python -c "import os; print('gh_actions=' + str(int(os.path.isdir('.github/actions'))))" >> $GITHUB_OUTPUT 64 | 65 | # https://github.com/SchemaStore/schemastore/blob/master/src/schemas/json/github-workflow.json 66 | - name: GitHub Actions - workflow 67 | run: | 68 | files=$(find .github/workflows -name '*.yml' -or -name '*.yaml' -not -name '__*') 69 | for f in $files; do 70 | echo $f; 71 | check-jsonschema -v $f --builtin-schema "github-workflows"; 72 | done 73 | 74 | # https://github.com/SchemaStore/schemastore/blob/master/src/schemas/json/github-action.json 75 | - name: GitHub Actions - action 76 | if: steps.folders.outputs.gh_actions == '1' 77 | run: | 78 | files=$(find .github/actions -name '*.yml' -or -name '*.yaml') 79 | for f in $files; do 80 | echo $f; 81 | check-jsonschema -v $f --builtin-schema "github-actions"; 82 | done 83 | 84 | # https://github.com/microsoft/azure-pipelines-vscode/blob/main/service-schema.json 85 | - name: Azure Pipelines 86 | if: ${{ inputs.azure-dir != '' }} 87 | env: 88 | SCHEMA_FILE: https://raw.githubusercontent.com/microsoft/azure-pipelines-vscode/${{ inputs.azure-schema-version }}/service-schema.json 89 | run: | 90 | files=$(find ${{ inputs.azure-dir }} -name '*.yml' -or -name '*.yaml') 91 | for f in $files; do 92 | echo $f; 93 | check-jsonschema -v $f --schemafile "$SCHEMA_FILE" --regex-variant="nonunicode"; 94 | done 95 | -------------------------------------------------------------------------------- /.github/workflows/check-typing.yml: -------------------------------------------------------------------------------- 1 | name: Check formatting flow 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | actions-ref: 7 | description: "Version of actions, normally the same as workflow" 8 | required: false 9 | type: string 10 | default: "" 11 | python-version: 12 | description: "Python version to use" 13 | default: "3.10" 14 | required: false 15 | type: 
string 16 | source-dir: 17 | description: "Source directory to check" 18 | default: "src/" 19 | required: false 20 | type: string 21 | extra-typing: 22 | description: "Package extra to be installed for type checks + include mypy" 23 | default: "test" 24 | required: false 25 | type: string 26 | 27 | defaults: 28 | run: 29 | shell: bash 30 | 31 | jobs: 32 | mypy: 33 | runs-on: ubuntu-24.04 34 | steps: 35 | - name: Checkout 🛎️ 36 | uses: actions/checkout@v4 37 | with: 38 | submodules: recursive 39 | 40 | - name: Set up Python 🐍 ${{ inputs.python-version }} 41 | uses: actions/setup-python@v5 42 | with: 43 | python-version: ${{ inputs.python-version }} 44 | 45 | - name: Install dependencies 46 | timeout-minutes: 20 47 | run: | 48 | # don't use --upgrade to respect the version installed via setup.py 49 | pip install -e '.[${{ inputs.extra-typing }}]' mypy \ 50 | --extra-index-url https://download.pytorch.org/whl/cpu/torch_stable.html 51 | pip list 52 | 53 | - name: Pull reusable 🤖 actions️ 54 | if: ${{ inputs.actions-ref != '' }} 55 | uses: actions/checkout@v4 56 | with: 57 | ref: ${{ inputs.actions-ref }} 58 | path: .cicd 59 | repository: Lightning-AI/utilities 60 | - name: Print 🖨️ dependencies 61 | if: ${{ inputs.actions-ref != '' }} 62 | uses: ./.cicd/.github/actions/pip-list 63 | with: 64 | unfold: true 65 | 66 | # see: https://github.com/python/mypy/issues/10600#issuecomment-857351152 67 | - name: init mypy 68 | continue-on-error: true 69 | run: | 70 | mkdir -p .mypy_cache 71 | mypy --install-types --non-interactive . 72 | 73 | - name: Check typing 74 | # mypy uses the config file found in the following order: 75 | # 1. mypy.ini 76 | # 2. pyproject.toml 77 | # 3. setup.cfg 78 | # 4. $XDG_CONFIG_HOME/mypy/config 79 | # 5. ~/.config/mypy/config 80 | # 6. 
~/.mypy.ini 81 | # https://mypy.readthedocs.io/en/stable/config_file.html 82 | run: mypy 83 | 84 | - name: suggest ignores 85 | if: failure() 86 | env: 87 | SOURCE_DIR: ${{ inputs.source-dir }} 88 | run: | 89 | mypy --no-error-summary 2>&1 \ 90 | | tr ':' ' ' \ 91 | | awk '{print $1}' \ 92 | | sort \ 93 | | uniq \ 94 | | sed 's/\.py//g; s|${SOURCE_DIR}||g; s|\/__init__||g; s|\/|\.|g' \ 95 | | xargs -I {} echo '"{}",' \ 96 | || true 97 | -------------------------------------------------------------------------------- /.github/workflows/ci-cli.yml: -------------------------------------------------------------------------------- 1 | name: Test CLI 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | test-cli: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: ["ubuntu-latest", "macos-latest", "windows-latest"] 20 | python-version: ["3.10"] 21 | timeout-minutes: 10 22 | steps: 23 | - name: Checkout 🛎 24 | uses: actions/checkout@v4 25 | - name: Set up Python 🐍 ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: install package 31 | run: | 32 | pip install -e '.[cli]' 33 | pip list 34 | 35 | - name: run CLI 36 | working-directory: ./requirements 37 | run: | 38 | python -m lightning_utilities.cli version 39 | python -m lightning_utilities.cli --help 40 | python -m lightning_utilities.cli requirements set-oldest --req_files="cli.txt" 41 | python -m lightning_utilities.cli requirements set-oldest --req_files='["cli.txt", "docs.txt"]' 42 | 43 | cli-guardian: 44 | runs-on: ubuntu-latest 45 | needs: test-cli 46 | if: always() 47 | steps: 48 | - run: echo "${{ needs.test-cli.result }}" 49 | - name: failing... 50 | if: needs.test-cli.result == 'failure' 51 | run: exit 1 52 | - name: cancelled or skipped... 
53 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.test-cli.result) 54 | timeout-minutes: 1 55 | run: sleep 90 56 | -------------------------------------------------------------------------------- /.github/workflows/ci-rtfd.yml: -------------------------------------------------------------------------------- 1 | name: RTFD Preview 2 | on: 3 | pull_request_target: 4 | types: 5 | - opened 6 | 7 | permissions: 8 | pull-requests: write 9 | 10 | jobs: 11 | documentation-links: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: readthedocs/actions/preview@v1 15 | with: 16 | project-slug: "lit-utilities" 17 | -------------------------------------------------------------------------------- /.github/workflows/ci-scripts.yml: -------------------------------------------------------------------------------- 1 | name: Test scripts 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | test-scripts: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: ["ubuntu-latest", "macos-latest", "windows-latest"] 20 | python-version: ["3.10"] 21 | timeout-minutes: 15 22 | steps: 23 | - name: Checkout 🛎 24 | uses: actions/checkout@v4 25 | - name: Set up Python 🐍 ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | cache: "pip" 30 | 31 | - name: Install dependencies 32 | timeout-minutes: 5 33 | run: | 34 | pip install -r requirements/_tests.txt 35 | pip --version 36 | pip list 37 | 38 | - name: test Scripts 39 | working-directory: ./scripts 40 | run: pytest . 
-v 41 | 42 | standalone-run: 43 | runs-on: "ubuntu-22.04" 44 | timeout-minutes: 20 45 | env: 46 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 47 | steps: 48 | - name: Checkout 🛎 49 | uses: actions/checkout@v4 50 | - name: Set up Python 🐍 ${{ matrix.python-version }} 51 | uses: actions/setup-python@v5 52 | with: 53 | python-version: "3.10" 54 | cache: "pip" 55 | - name: Install dependencies 56 | timeout-minutes: 20 57 | run: | 58 | set -e 59 | pip install -e . -U -r requirements/_tests.txt -f $TORCH_URL 60 | pip --version 61 | pip list 62 | 63 | - name: Run standalone script 64 | run: bash ./scripts/run_standalone_tests.sh "tests" 65 | env: 66 | COVERAGE_SOURCE: "lightning_utilities" 67 | 68 | scripts-guardian: 69 | runs-on: ubuntu-latest 70 | needs: [test-scripts, standalone-run] 71 | if: always() 72 | steps: 73 | - run: echo "${{ needs.test-scripts.result }}" 74 | - name: failing... 75 | if: needs.test-scripts.result == 'failure' 76 | run: exit 1 77 | - name: cancelled or skipped... 
78 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.test-scripts.result) 79 | timeout-minutes: 1 80 | run: sleep 90 81 | -------------------------------------------------------------------------------- /.github/workflows/ci-testing.yml: -------------------------------------------------------------------------------- 1 | name: UnitTests 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | pytester: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: ["ubuntu-22.04", "macos-13", "windows-2022"] 20 | python-version: ["3.9", "3.11", "3.13"] 21 | requires: ["oldest", "latest"] 22 | exclude: 23 | - { requires: "oldest", python-version: "3.13" } 24 | timeout-minutes: 35 25 | env: 26 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 27 | steps: 28 | - name: Checkout 🛎 29 | uses: actions/checkout@v4 30 | with: 31 | submodules: recursive 32 | - name: Set up Python 🐍 ${{ matrix.python-version }} 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | - name: Set oldest dependencies 38 | if: matrix.requires == 'oldest' 39 | timeout-minutes: 20 40 | run: | 41 | pip install -e '.[cli]' 42 | python -m lightning_utilities.cli requirements set-oldest 43 | 44 | - name: Complex 💽 caching 45 | uses: ./.github/actions/cache 46 | with: 47 | python-version: ${{ matrix.python-version }} 48 | 49 | - name: Install dependencies 50 | timeout-minutes: 20 51 | run: | 52 | set -e 53 | pip install -e . 
-U -r requirements/_tests.txt -f $TORCH_URL 54 | pip --version 55 | pip list 56 | 57 | - name: Print 🖨️ dependencies 58 | uses: ./.github/actions/pip-list 59 | 60 | - name: Unittest and coverage 61 | uses: ./.github/actions/unittesting 62 | with: 63 | python-version: ${{ matrix.python-version }} 64 | dirs: "unittests" 65 | pkg-name: "lightning_utilities" 66 | pytest-args: "--timeout=120" 67 | 68 | - name: Upload coverage to Codecov 69 | uses: codecov/codecov-action@v5.4.3 70 | continue-on-error: true 71 | with: 72 | token: ${{ secrets.CODECOV_TOKEN }} 73 | file: ./coverage.xml 74 | flags: unittests 75 | env_vars: OS,PYTHON 76 | name: codecov-umbrella 77 | fail_ci_if_error: false 78 | 79 | - name: test CI scripts 80 | working-directory: ./tests 81 | run: python -m pytest scripts --durations=50 --timeout=120 82 | 83 | testing-guardian: 84 | runs-on: ubuntu-latest 85 | needs: pytester 86 | if: always() 87 | steps: 88 | - run: echo "${{ needs.pytester.result }}" 89 | - name: failing... 90 | if: needs.pytester.result == 'failure' 91 | run: exit 1 92 | - name: cancelled or skipped... 93 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) 94 | timeout-minutes: 1 95 | run: sleep 90 96 | -------------------------------------------------------------------------------- /.github/workflows/ci-use-checks.yaml: -------------------------------------------------------------------------------- 1 | name: Apply checks 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ ! 
(github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }} 12 | 13 | jobs: 14 | check-typing: 15 | uses: ./.github/workflows/check-typing.yml 16 | with: 17 | actions-ref: ${{ github.sha }} # use local version 18 | extra-typing: "typing" 19 | 20 | check-precommit: 21 | uses: ./.github/workflows/check-precommit.yml 22 | 23 | check-schema-latest: 24 | uses: ./.github/workflows/check-schema.yml 25 | with: 26 | azure-dir: "" 27 | 28 | check-schema-fixed: 29 | uses: ./.github/workflows/check-schema.yml 30 | with: 31 | actions-ref: ${{ github.sha }} # use local version 32 | azure-dir: "" 33 | 34 | check-schema: 35 | runs-on: ubuntu-latest 36 | # just aggregation of the previous two jobs 37 | needs: ["check-schema-latest", "check-schema-fixed"] 38 | steps: 39 | - run: echo "done" 40 | 41 | check-package: 42 | uses: ./.github/workflows/check-package.yml 43 | with: 44 | actions-ref: ${{ github.sha }} # use local version 45 | artifact-name: dist-packages-${{ github.sha }} 46 | import-name: "lightning_utilities" 47 | build-matrix: | 48 | { 49 | "os": ["ubuntu-22.04"], 50 | "python-version": ["3.10"] 51 | } 52 | testing-matrix: | 53 | { 54 | "os": ["ubuntu-22.04", "macos-14", "windows-2022"], 55 | "python-version": ["3.9", "3.13"] 56 | } 57 | 58 | check-package-extras: 59 | uses: ./.github/workflows/check-package.yml 60 | with: 61 | actions-ref: ${{ github.sha }} # use local version 62 | artifact-name: dist-packages-extras-${{ github.sha }} 63 | import-name: "lightning_utilities" 64 | install-extras: "[cli]" 65 | # todo: when we have a module that depends on an extra, replace it 66 | # tried importing `lightning_utilities.cli.__main__`, but it was reported as not existing 67 | custom-import-code: "import jsonargparse" 68 | build-matrix: | 69 | { 70 | "os": ["ubuntu-latest", "macos-latest"], 71 | "python-version": ["3.10"] 72 | } 73 | testing-matrix: | 74 | { 75 | "os": ["ubuntu-latest", "macos-latest", "windows-latest"], 76 | "python-version": ["3.10"] 77 |
} 78 | 79 | check-docs: 80 | uses: ./.github/workflows/check-docs.yml 81 | with: 82 | actions-ref: ${{ github.sha }} # use local version 83 | requirements-file: "requirements/_docs.txt" 84 | install-tex: true 85 | 86 | check-md-links-default: 87 | uses: ./.github/workflows/check-md-links.yml 88 | with: 89 | base-branch: main 90 | 91 | check-md-links-w-config: 92 | uses: ./.github/workflows/check-md-links.yml 93 | with: 94 | base-branch: main 95 | config-file: ".github/markdown.links.config.json" 96 | -------------------------------------------------------------------------------- /.github/workflows/cleanup-caches.yml: -------------------------------------------------------------------------------- 1 | name: Cleaning caches 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | scripts-ref: 7 | description: "Version of script, normally the same as workflow" 8 | required: true 9 | type: string 10 | #gh-token: 11 | # description: 'PAT which is authorized to delete caches for given repo' 12 | # required: true 13 | # type: string 14 | dry-run: 15 | description: "only list caches without deleting any; options true|false" 16 | required: true 17 | type: string 18 | age-days: 19 | description: "setting the age of caches in days to be dropped" 20 | required: true 21 | type: number 22 | default: 7 23 | pattern: 24 | description: "string to grep cache keys with" 25 | required: false 26 | type: string 27 | default: "" 28 | 29 | defaults: 30 | run: 31 | shell: bash 32 | 33 | jobs: 34 | cleanup-caches: 35 | runs-on: ubuntu-latest 36 | timeout-minutes: 15 37 | env: 38 | GH_TOKEN: ${{ github.token }} 39 | AGE_DAYS: ${{ inputs.age-days }} 40 | steps: 41 | - name: Checkout Code 42 | uses: actions/checkout@v4 43 | 44 | - name: Pull reusable 🤖 actions️ 45 | uses: actions/checkout@v4 46 | with: 47 | ref: ${{ inputs.scripts-ref }} 48 | path: .cicd 49 | repository: Lightning-AI/utilities 50 | - name: install requirements 51 | timeout-minutes: 20 52 | run: pip install -U -r
./.cicd/.github/scripts/find-unused-caches.txt 53 | 54 | - name: List and Filter 🔍 caches 55 | run: | 56 | python ./.cicd/.github/scripts/find-unused-caches.py \ 57 | --repository="${{ github.repository }}" --token=${GH_TOKEN} --age_days=${AGE_DAYS} 58 | cat unused-cashes.txt 59 | 60 | - name: Delete 🗑️ caches 61 | if: inputs.dry-run != 'true' 62 | run: | 63 | # Use a while loop to read each line from the file 64 | while read -r line || [ -n "$line" ]; do 65 | echo "$line"; 66 | # delete each cache based on file... 67 | gh api --method DELETE -H "Accept: application/vnd.github+json" /repos/${{ github.repository }}/actions/caches/ $line; 68 | done < "unused-cashes.txt" 69 | -------------------------------------------------------------------------------- /.github/workflows/cron-clear-cache.yml: -------------------------------------------------------------------------------- 1 | name: Clear cache weekly 2 | 3 | on: 4 | schedule: 5 | # on Sunday night at 2am 6 | - cron: "0 2 * * 0" 7 | pull_request: 8 | paths: 9 | - ".github/scripts/find-unused-caches.py" 10 | - ".github/scripts/find-unused-caches.txt" 11 | - ".github/workflows/cleanup-caches.yml" 12 | - ".github/workflows/cron-clear-cache.yml" 13 | workflow_dispatch: 14 | inputs: 15 | pattern: 16 | description: "pattern for cleaning cache" 17 | default: "pip" 18 | required: false 19 | type: string 20 | age-days: 21 | description: "setting the age of caches in days to be dropped" 22 | required: true 23 | type: number 24 | default: 7 25 | 26 | jobs: 27 | drop-unused-caches: 28 | uses: ./.github/workflows/cleanup-caches.yml 29 | with: 30 | scripts-ref: ${{ github.sha }} # use local version 31 | dry-run: ${{ github.event_name == 'pull_request' }} 32 | # use the input if set, otherwise the default...
33 | pattern: ${{ inputs.pattern || 'pip|conda' }} 34 | age-days: ${{ fromJSON(inputs.age-days) || 2 }} 35 | -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: "Deploy Docs" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | 8 | defaults: 9 | run: 10 | shell: bash 11 | 12 | jobs: 13 | # https://github.com/marketplace/actions/deploy-to-github-pages 14 | build-docs-deploy: 15 | runs-on: ubuntu-latest 16 | env: 17 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 18 | steps: 19 | - name: Checkout 🛎️ 20 | uses: actions/checkout@v4 21 | # If you're using actions/checkout@v4 you must set persist-credentials to false in most cases for the deployment to work correctly. 22 | with: 23 | persist-credentials: false 24 | submodules: recursive 25 | - name: Set up Python 🐍 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: "3.10" 29 | cache: "pip" 30 | 31 | # Note: This uses an internal pip API and may not always work 32 | # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow 33 | - name: Cache 💽 pip 34 | uses: actions/cache@v4 35 | with: 36 | path: ~/.cache/pip 37 | key: pip-${{ hashFiles('requirements/*.txt') }} 38 | restore-keys: pip- 39 | 40 | #- name: Install texlive 41 | # timeout-minutes: 20 42 | # run: | 43 | # # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux 44 | # sudo apt-get update --fix-missing 45 | # sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures 46 | # shell: bash 47 | 48 | - name: Install dependencies 49 | timeout-minutes: 20 50 | run: | 51 | pip --version 52 | pip install -e . 
-U -q -r requirements/_docs.txt -f ${TORCH_URL} 53 | pip list 54 | shell: bash 55 | 56 | - name: Make Documentation 57 | working-directory: ./docs 58 | run: make html --jobs 2 59 | 60 | - name: Deploy 🚀 61 | uses: JamesIves/github-pages-deploy-action@v4.7.3 62 | if: github.ref == 'refs/heads/main' 63 | with: 64 | token: ${{ secrets.GITHUB_TOKEN }} 65 | branch: gh-pages # The branch the action should deploy to. 66 | folder: docs/build/html # The folder the action should deploy. 67 | clean: true # Automatically remove deleted files from the deploy branch 68 | target-folder: docs # push the contents of the deployment folder into this subdirectory 69 | single-commit: true # keep a single commit on the deployment branch instead of the full history 70 | -------------------------------------------------------------------------------- /.github/workflows/label-pr.yml: -------------------------------------------------------------------------------- 1 | name: Label Pull Requests 2 | on: [pull_request_target] 3 | 4 | jobs: 5 | triage: 6 | permissions: 7 | contents: read 8 | pull-requests: write 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/labeler@v5 12 | with: 13 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 14 | configuration-path: .github/labeler-config.yml 15 | -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger on pushes to main/release branches and version tags, on published releases, and on pull requests to main 5 | push: 6 | branches: [main, "release/*"] 7 | tags: ["v?[0-9]+.[0-9]+.[0-9]+"] 8 | release: 9 | types: [published] 10 | pull_request: 11 | branches: [main] 12 | 13 | defaults: 14 | run: 15 | shell: bash 16 | 17 | jobs: 18 | # based on 
https://github.com/pypa/gh-action-pypi-publish 19
| build-package: 20 | runs-on: ubuntu-22.04 21 | timeout-minutes: 10 22 | steps: 23 | - name: Checkout 🛎️ 24 | uses: actions/checkout@v4 25 | - name: Set up Python 🐍 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: "3.10" 29 | - name: Prepare build env. 30 | run: pip install -r ./requirements/gha-package.txt 31 | - name: Create 📦 package 32 | uses: ./.github/actions/pkg-create 33 | - name: Upload 📤 packages 34 | uses: actions/upload-artifact@v4 35 | with: 36 | name: pypi-packages-${{ github.sha }} 37 | path: dist 38 | 39 | upload-package: 40 | needs: build-package 41 | if: github.event_name == 'release' 42 | timeout-minutes: 5 43 | runs-on: ubuntu-latest 44 | steps: 45 | - name: Checkout 🛎️ 46 | uses: actions/checkout@v4 47 | - name: Download 📥 artifact 48 | uses: actions/download-artifact@v4 49 | with: 50 | name: pypi-packages-${{ github.sha }} 51 | path: dist 52 | - name: local 🗃️ files 53 | run: ls -lh dist/ 54 | 55 | - name: Upload to release 56 | uses: AButler/upload-release-assets@v3.0 57 | with: 58 | files: "dist/*" 59 | repo-token: ${{ secrets.GITHUB_TOKEN }} 60 | 61 | publish-package: 62 | needs: build-package 63 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') 64 | runs-on: ubuntu-latest 65 | timeout-minutes: 5 66 | steps: 67 | - name: Checkout 🛎️ 68 | uses: actions/checkout@v4 69 | with: 70 | submodules: recursive 71 | - name: Download 📥 artifact 72 | uses: actions/download-artifact@v4 73 | with: 74 | name: pypi-packages-${{ github.sha }} 75 | path: dist 76 | - name: local 🗃️ files 77 | run: ls -lh dist/ 78 | 79 | # We do this, since failures on test.pypi aren't that bad 80 | - name: Publish to Test PyPI 81 | uses: pypa/gh-action-pypi-publish@v1.12.4 82 | with: 83 | user: __token__ 84 | password: ${{ secrets.test_pypi_password }} 85 | repository-url: https://test.pypi.org/legacy/ 86 | verbose: true 87 | 88 | - name: Publish distribution 📦 to PyPI 89 | uses: pypa/gh-action-pypi-publish@v1.12.4 90 | with: 91 
| user: __token__ 92 | password: ${{ secrets.pypi_password }} 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | docs/source/fetched-s3-assets 60 | docs/source/api/ 61 | docs/source/*.md 62 | 63 | # PyBuilder 64 | target/ 65 | 66 | # Jupyter Notebook 67 | .ipynb_checkpoints 68 | 69 | # IPython 70 | profile_default/ 71 | ipython_config.py 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 77 | __pypackages__/ 78 | 79 | # Celery stuff 80 | celerybeat-schedule 81 | celerybeat.pid 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | .dmypy.json 108 | dmypy.json 109 | 110 | # Pyre type checker 111 | .pyre/ 112 | 113 | # PyCharm 114 | .idea/ 115 | 116 | # Lightning logs 117 | lightning_logs 118 | *.gz 119 | .DS_Store 120 | .*_submit.py 121 | 122 | # Formatting 123 | .ruff_cache/ 124 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" 7 | autoupdate_schedule: monthly 8 | # submodules: true 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v5.0.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | - id: check-case-conflict 17 | - id: check-yaml 18 | - id: check-toml 19 | - id: check-json 20 | - id: check-added-large-files 21 | - id: check-docstring-first 22 | - id: detect-private-key 23 | 24 | - repo: https://github.com/codespell-project/codespell 25 | rev: v2.4.1 26 | hooks: 27 | - id: codespell 28 | additional_dependencies: [tomli] 29 | #args: ["--write-changes"] # uncomment if you want to get automatic fixing 30 | 31 | - repo: https://github.com/PyCQA/docformatter 32 | rev: v1.7.7 33 | hooks: 34 | - id: docformatter 35 | additional_dependencies: [tomli] 36 | args: ["--in-place"] 37 | 38 | - repo: https://github.com/executablebooks/mdformat 39 | rev: 0.7.22 40 |
hooks: 41 | - id: mdformat 42 | additional_dependencies: 43 | - mdformat-gfm 44 | - mdformat-black 45 | - mdformat_frontmatter 46 | args: ["--number"] 47 | exclude: CHANGELOG.md 48 | 49 | - repo: https://github.com/JoC0de/pre-commit-prettier 50 | rev: v3.5.3 51 | hooks: 52 | - id: prettier 53 | files: \.(json|yml|yaml|toml) 54 | # https://prettier.io/docs/en/options.html#print-width 55 | args: ["--print-width=120"] 56 | 57 | - repo: https://github.com/astral-sh/ruff-pre-commit 58 | rev: v0.11.12 59 | hooks: 60 | - id: ruff 61 | args: ["--fix"] 62 | - id: ruff-format 63 | - id: ruff 64 | 65 | - repo: https://github.com/sphinx-contrib/sphinx-lint 66 | rev: v1.0.0 67 | hooks: 68 | - id: sphinx-lint 69 | 70 | - repo: https://github.com/tox-dev/pyproject-fmt 71 | rev: v2.6.0 72 | hooks: 73 | - id: pyproject-fmt 74 | additional_dependencies: [tox] 75 | - repo: https://github.com/abravalheri/validate-pyproject 76 | rev: v0.24.1 77 | hooks: 78 | - id: validate-pyproject 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018-2021 William Falcon 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CHANGELOG.md 2 | recursive-include src py.typed 3 | 4 | graft requirements 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean docs 2 | 3 | # to imitate SLURM, set only a single node 4 | export SLURM_LOCALID=0 5 | # assume you have installed the needed packages 6 | export SPHINX_MOCK_REQUIREMENTS=0 7 | 8 | test: 9 | pip install -q -r requirements/cli.txt -r requirements/_tests.txt 10 | 11 | # use this to run tests 12 | rm -rf _ckpt_* 13 | rm -rf ./lightning_logs 14 | python -m coverage run --source src/lightning_utilities -m pytest src/lightning_utilities tests -v 15 | python -m coverage report 16 | 17 | # specific file 18 | # python -m coverage run --source src/lightning_utilities -m pytest --flake8 --durations=0 -v -k 19 | 20 | docs: clean 21 | pip install -e .
-q -r requirements/_docs.txt 22 | cd docs && $(MAKE) html 23 | 24 | clean: 25 | # clean all temp runs 26 | rm -rf .mypy_cache 27 | rm -rf .pytest_cache 28 | rm -rf .ruff_cache 29 | rm -rf build 30 | rm -rf dist 31 | rm -rf src/*.egg-info 32 | rm -rf ./docs/build 33 | rm -rf ./docs/source/**/generated 34 | rm -rf ./docs/source/api 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lightning Utilities 2 | 3 | [![PyPI Status](https://badge.fury.io/py/lightning-utilities.svg)](https://badge.fury.io/py/lightning-utilities) 4 | [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/utilities/blob/master/LICENSE) 5 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/lightning-utilities)](https://pepy.tech/project/lightning-utilities) 6 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/lightning-utilities)](https://pypi.org/project/lightning-utilities/) 7 | 8 | [![UnitTests](https://github.com/Lightning-AI/utilities/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/utilities/actions/workflows/ci-testing.yml) 9 | [![Apply checks](https://github.com/Lightning-AI/utilities/actions/workflows/ci-use-checks.yaml/badge.svg?event=push)](https://github.com/Lightning-AI/utilities/actions/workflows/ci-use-checks.yaml) 10 | [![Docs Status](https://readthedocs.org/projects/lit-utilities/badge/?version=latest)](https://lit-utilities.readthedocs.io/en/latest/?badge=latest) 11 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/utilities/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/utilities/main) 12 | 13 | __This repository covers the following use-cases:__ 14 | 15 | 1. _Reusable GitHub workflows_ 16 | 2. _Shared GitHub actions_ 17 | 3. 
_General Python utilities in `lightning_utilities.core`_ 18 | 4. _CLI `python -m lightning_utilities.cli --help`_ 19 | 20 | ## 1. Reusable workflows 21 | 22 | __Usage:__ 23 | 24 | ```yaml 25 | name: Check schema 26 | 27 | on: [push] 28 | 29 | jobs: 30 | 31 | check-schema: 32 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.5.0 33 | with: 34 | azure-dir: "" # skip Azure check 35 | 36 | check-typing: 37 | uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main 38 | with: 39 | actions-ref: main # normally you should use the same version as the workflow 40 | ``` 41 | 42 | See usage of the other workflows in [.github/workflows/ci-use-checks.yaml](https://github.com/Lightning-AI/utilities/tree/main/.github/workflows/ci-use-checks.yaml). 43 | 44 | ## 2. Reusable composite actions 45 | 46 | See the available composite actions in [.github/actions/](https://github.com/Lightning-AI/utilities/tree/main/.github/actions). 47 | 48 | __Usage:__ 49 | 50 | ```yaml 51 | name: Do something with cache 52 | 53 | on: [push] 54 | 55 | jobs: 56 | pytest: 57 | runs-on: ubuntu-24.04 58 | steps: 59 | - uses: actions/checkout@v3 60 | - uses: actions/setup-python@v4 61 | with: 62 | python-version: 3.9 63 | - uses: Lightning-AI/utilities/.github/actions/cache@main 64 | with: 65 | python-version: 3.9 66 | requires: oldest # or latest 67 | ``` 68 | 69 | ## 3. General Python utilities `lightning_utilities.core` 70 | 71 |
72 | Installation 73 | From source: 74 | 75 | ```bash 76 | pip install https://github.com/Lightning-AI/utilities/archive/refs/heads/main.zip 77 | ``` 78 | 79 | From PyPI: 80 | 81 | ```bash 82 | pip install lightning_utilities 83 | ``` 84 | 85 |
86 | 87 | __Usage:__ 88 | 89 | Example for optional imports: 90 | 91 | ```python 92 | from lightning_utilities.core.imports import module_available 93 | 94 | if module_available("some_package.something"): 95 | from some_package import something 96 | ``` 97 | 98 | ## 4. CLI `lightning_utilities.cli` 99 | 100 | The package provides common CLI commands. 101 | 102 |
103 | Installation 104 | 105 | From PyPI: 106 | 107 | ```bash 108 | pip install "lightning_utilities[cli]" 109 | ``` 110 | 111 |
112 | 113 | __Usage:__ 114 | 115 | ```bash 116 | python -m lightning_utilities.cli [group] [command] 117 | ``` 118 | 119 |
120 | Example of setting minimum versions 121 | 122 | ```console 123 | $ cat requirements/test.txt 124 | coverage>=5.0 125 | codecov>=2.1 126 | pytest>=6.0 127 | pytest-cov 128 | pytest-timeout 129 | $ python -m lightning_utilities.cli requirements set-oldest 130 | $ cat requirements/test.txt 131 | coverage==5.0 132 | codecov==2.1 133 | pytest==6.0 134 | pytest-cov 135 | pytest-timeout 136 | ``` 137 | 138 |
139 | -------------------------------------------------------------------------------- /docs/.build_docs.sh: -------------------------------------------------------------------------------- 1 | make clean 2 | make html --debug --jobs $(nproc) 3 | -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # This .readthedocs.yaml lives in docs/, so its path needs to be set in the RTFD settings UI 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | # You can also specify other tool versions: 14 | # nodejs: "20" 15 | 16 | # Build documentation in the docs/ directory with Sphinx 17 | sphinx: 18 | configuration: docs/source/conf.py 19 | # Fail on all warnings to avoid broken references 20 | fail_on_warning: true 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | formats: 24 | - htmlzip 25 | - pdf 26 | 27 | # Optionally declare the Python requirements needed to build your docs 28 | python: 29 | install: 30 | - requirements: requirements/_docs.txt 31 | - method: pip 32 | path: . 33 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -T -W 6 | SPHINXBUILD = python $(shell which sphinx-build) 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help".
11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/copybutton.js: -------------------------------------------------------------------------------- 1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */ 2 | $(document).ready(function () { 3 | /* Add a [>>>] button on the top-right corner of code samples to hide 4 | * the >>> and ... prompts and the output and thus make the code 5 | * copyable. 
*/ 6 | var div = $( 7 | ".highlight-python .highlight," + 8 | ".highlight-python3 .highlight," + 9 | ".highlight-pycon .highlight," + 10 | ".highlight-default .highlight", 11 | ); 12 | var pre = div.find("pre"); 13 | 14 | // get the styles from the current theme 15 | pre.parent().parent().css("position", "relative"); 16 | var hide_text = "Hide the prompts and output"; 17 | var show_text = "Show the prompts and output"; 18 | var border_width = pre.css("border-top-width"); 19 | var border_style = pre.css("border-top-style"); 20 | var border_color = pre.css("border-top-color"); 21 | var button_styles = { 22 | cursor: "pointer", 23 | position: "absolute", 24 | top: "0", 25 | right: "0", 26 | "border-color": border_color, 27 | "border-style": border_style, 28 | "border-width": border_width, 29 | color: border_color, 30 | "text-size": "75%", 31 | "font-family": "monospace", 32 | "padding-left": "0.2em", 33 | "padding-right": "0.2em", 34 | "border-radius": "0 3px 0 0", 35 | }; 36 | 37 | // create and add the button to all the code blocks that contain >>> 38 | div.each(function (index) { 39 | var jthis = $(this); 40 | if (jthis.find(".gp").length > 0) { 41 | var button = $('>>>'); 42 | button.css(button_styles); 43 | button.attr("title", hide_text); 44 | button.data("hidden", "false"); 45 | jthis.prepend(button); 46 | } 47 | // tracebacks (.gt) contain bare text elements that need to be 48 | // wrapped in a span to work with .nextUntil() (see later) 49 | jthis 50 | .find("pre:has(.gt)") 51 | .contents() 52 | .filter(function () { 53 | return this.nodeType == 3 && this.data.trim().length > 0; 54 | }) 55 | .wrap(""); 56 | }); 57 | 58 | // define the behavior of the button when it's clicked 59 | $(".copybutton").click(function (e) { 60 | e.preventDefault(); 61 | var button = $(this); 62 | if (button.data("hidden") === "false") { 63 | // hide the code output 64 | button.parent().find(".go, .gp, .gt").hide(); 65 | button.next("pre").find(".gt").nextUntil(".gp, 
.go").css("visibility", "hidden"); 66 | button.css("text-decoration", "line-through"); 67 | button.attr("title", show_text); 68 | button.data("hidden", "true"); 69 | } else { 70 | // show the code output 71 | button.parent().find(".go, .gp, .gt").show(); 72 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "visible"); 73 | button.css("text-decoration", "none"); 74 | button.attr("title", hide_text); 75 | button.data("hidden", "false"); 76 | } 77 | }); 78 | }); 79 | -------------------------------------------------------------------------------- /docs/source/_static/images/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 20 | 22 | image/svg+xml 23 | 25 | 26 | 27 | 28 | 29 | 31 | 51 | 56 | 62 | 63 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo-large.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 20 | 22 | image/svg+xml 23 | 25 | 26 | 27 | 28 | 30 | 50 | 55 | 61 | 62 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo-small.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 20 | 22 | image/svg+xml 23 | 25 | 26 | 27 | 28 | 29 | 31 | 51 | 56 | 62 | 63 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/utilities/2658ff8b341fc0f30e9d9e9042a9d6c0e4de49a2/docs/source/_static/images/logo.png -------------------------------------------------------------------------------- /docs/source/_static/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 42 | 44 | 45 | 47 | image/svg+xml 48 | 50 | 51 | 52 | 53 
| 54 | 59 | 63 | 64 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/Lightning-AI/utilities', 3 | 'github_issues': 'https://github.com/Lightning-AI/utilities/issues', 4 | 'contributing': 'https://github.com/Lightning-AI/lightning/blob/master/CONTRIBUTING.md', 5 | 'governance': 'https://github.com/Lightning-AI/lightning/blob/master/governance.md', 6 | 'docs': 'https://dev-toolbox.rtfd.io/en/latest', 7 | 'twitter': 'https://twitter.com/PyTorchLightnin', 8 | 'discuss': 'https://pytorch-lightning.slack.com', 9 | 'home': 'https://lightning-tools.rtfd.io/en/latest/', 10 | 'get_started': 'https://lightning-tools.readthedocs.io/en/latest/introduction_guide.html', 11 | 'features': 'https://lightning-tools.rtfd.io/en/latest/', 12 | 'blog': 'https://www.pytorchlightning.ai/blog', 13 | 'support': 'https://lightning-tools.rtfd.io/en/latest/', 14 | } 15 | -%} 16 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. LightningAI-DevToolbox documentation master file, created by 2 | sphinx-quickstart on Wed Mar 25 21:34:07 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Lightning-DevToolbox documentation 7 | ================================== 8 | 9 | .. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/Lightning.gif 10 | :alt: What is Lightning gif. 11 | :width: 80 % 12 | 13 | .. 
toctree:: 14 | :maxdepth: 1 15 | :name: content 16 | :caption: Overview 17 | 18 | Utilities readme 19 | 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /docs/source/test-page.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | Testing page 4 | ============ 5 | 6 | This is some page serving exclusively for testing purposes. 7 | 8 | Link to scikit-learn stable documentation: https://scikit-learn.org/stable/index.html 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | 7 | [tool.ruff] 8 | target-version = "py39" 9 | 10 | line-length = 120 11 | format.preview = true 12 | lint.select = [ 13 | "D", # see: https://pypi.org/project/pydocstyle 14 | "E", 15 | "F", # see: https://pypi.org/project/pyflakes 16 | "I", #see: https://pypi.org/project/isort/ 17 | "N", # see: https://pypi.org/project/pep8-naming 18 | "RUF100", # see: https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf 19 | "S", # see: https://pypi.org/project/flake8-bandit 20 | "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up 21 | "W", # see: https://pypi.org/project/pycodestyle 22 | ] 23 | lint.extend-select = [ 24 | "A", # see: https://pypi.org/project/flake8-builtins 25 | "ANN", # see: https://pypi.org/project/flake8-annotations 26 | "B", # see: https://pypi.org/project/flake8-bugbear 27 | "C4", # see: https://pypi.org/project/flake8-comprehensions 28 | "EXE", # see: https://pypi.org/project/flake8-executable 29 | "ISC", # see: https://pypi.org/project/flake8-implicit-str-concat 30 | "PIE", # see: https://pypi.org/project/flake8-pie 31 | "PLE", # see: 
https://pypi.org/project/pylint/ 32 | "PT", # see: https://pypi.org/project/flake8-pytest-style 33 | "Q", # see: https://pypi.org/project/flake8-quotes 34 | "RET", # see: https://pypi.org/project/flake8-return 35 | "RUF", # Ruff-specific rules 36 | "SIM", # see: https://pypi.org/project/flake8-simplify 37 | "T10", # see: https://pypi.org/project/flake8-debugger 38 | "TID", # see: https://pypi.org/project/flake8-tidy-imports/ 39 | "YTT", # see: https://pypi.org/project/flake8-2020 40 | ] 41 | lint.ignore = [ 42 | "E731", 43 | ] 44 | lint.per-file-ignores."__about__.py" = [ 45 | "D100", 46 | ] 47 | lint.per-file-ignores."__init__.py" = [ 48 | "D100", 49 | ] 50 | lint.per-file-ignores."docs/source/conf.py" = [ 51 | "A001", 52 | "ANN001", 53 | "ANN201", 54 | "D100", 55 | "D103", 56 | ] 57 | lint.per-file-ignores."setup.py" = [ 58 | "ANN202", 59 | "D100", 60 | "SIM115", 61 | ] 62 | lint.per-file-ignores."src/**" = [ 63 | "ANN101", # Missing type annotation for `self` in method 64 | "ANN102", # Missing type annotation for `cls` in classmethod 65 | "ANN401", # Dynamically typed expressions (typing.Any) 66 | "B905", # `zip()` without an explicit `strict=` parameter 67 | "D100", # Missing docstring in public module 68 | "D107", # Missing docstring in `__init__` 69 | ] 70 | lint.per-file-ignores."tests/**" = [ 71 | "ANN001", # Missing type annotation for function argument 72 | "ANN101", # Missing type annotation for `self` in method 73 | "ANN201", # Missing return type annotation for public function 74 | "ANN202", # Missing return type annotation for private function 75 | "ANN204", # Missing return type annotation for special method 76 | "ANN401", # Dynamically typed expressions (typing.Any) 77 | "B028", # No explicit `stacklevel` keyword argument found 78 | "B905", # `zip()` without an explicit `strict=` parameter 79 | "D100", # Missing docstring in public module 80 | "D101", # Missing docstring in public class 81 | "D102", # Missing docstring in public method 82 | "D103", 
# Missing docstring in public function 83 | "D104", # Missing docstring in public package 84 | "D105", # Missing docstring in magic method 85 | "D107", # Missing docstring in `__init__` 86 | "S101", # Use of `assert` detected 87 | "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes 88 | ] 89 | lint.mccabe.max-complexity = 10 90 | # Use Google-style docstrings. 91 | lint.pydocstyle.convention = "google" 92 | 93 | [tool.codespell] 94 | #skip = '*.py' 95 | quiet-level = 3 96 | # comma-separated list of words; waiting for: 97 | # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603 98 | # also adding links until codespell ignores them by nature, see: 99 | # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960 100 | #ignore-words-list = "" 101 | 102 | [tool.docformatter] 103 | recursive = true 104 | # this needs to be shorter as some docstrings are r"""... 105 | wrap-summaries = 119 106 | wrap-descriptions = 120 107 | blank = true 108 | 109 | [tool.check-manifest] 110 | ignore = [ 111 | "*.yml", 112 | ".github", 113 | ".github/*", 114 | ] 115 | 116 | [tool.pytest.ini_options] 117 | norecursedirs = [ 118 | ".git", 119 | ".github", 120 | "dist", 121 | "build", 122 | "docs", 123 | ] 124 | addopts = [ 125 | "--strict-markers", 126 | "--doctest-modules", 127 | "--durations=25", 128 | "--color=yes", 129 | "--disable-pytest-warnings", 130 | ] 131 | markers = [ 132 | "online: run tests that require internet connection", 133 | ] 134 | filterwarnings = [ 135 | "error::FutureWarning", 136 | ] # todo: "error::DeprecationWarning" 137 | xfail_strict = true 138 | 139 | [tool.coverage.report] 140 | exclude_lines = [ 141 | "pragma: no cover", 142 | "pass", 143 | ] 144 | #[tool.coverage.run] 145 | #parallel = true 146 | #concurrency = ['thread'] 147 | #relative_files = true 148 | 149 | [tool.mypy] 150 | files = [ 151 | "src/lightning_utilities", 152 | ] 153 | disallow_untyped_defs = true 154 |
ignore_missing_imports = true 155 | -------------------------------------------------------------------------------- /requirements/_docs.txt: -------------------------------------------------------------------------------- 1 | sphinx >=5.0,<6.0 2 | myst-parser >=1.0.0, <2.0.0 3 | nbsphinx >=0.8.5 4 | ipython[notebook] 5 | pandoc >=1.0 6 | docutils >=0.16 7 | # https://github.com/jupyterlab/jupyterlab_pygments/issues/5 8 | pygments >=2.4.1 9 | #sphinxcontrib-fulltoc >=1.0 10 | #sphinxcontrib-mockautodoc 11 | 12 | pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip 13 | sphinx-autodoc-typehints >=1.0 14 | sphinx-paramlinks >=0.5.1 15 | sphinx-togglebutton >=0.2 16 | sphinx-copybutton >=0.3 17 | -------------------------------------------------------------------------------- /requirements/_tests.txt: -------------------------------------------------------------------------------- 1 | coverage ==7.2.7; python_version < '3.8' 2 | coverage ==7.8.*; python_version >= '3.8' 3 | pytest ==7.4.3; python_version < '3.8' 4 | pytest ==8.3.*; python_version >= '3.8' 5 | pytest-cov ==6.1.1 6 | pytest-timeout ==2.4.0 7 | -------------------------------------------------------------------------------- /requirements/cli.txt: -------------------------------------------------------------------------------- 1 | jsonargparse[signatures] >=4.38.0 2 | -------------------------------------------------------------------------------- /requirements/core.txt: -------------------------------------------------------------------------------- 1 | importlib-metadata >=4.0.0; python_version < '3.8' 2 | packaging >=17.1 3 | setuptools 4 | typing_extensions 5 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | requests >=2.0.0 2 | -------------------------------------------------------------------------------- 
/requirements/gha-package.txt: -------------------------------------------------------------------------------- 1 | twine ==6.1.* 2 | setuptools ==80.9.* 3 | wheel ==0.45.* 4 | build ==1.2.* 5 | importlib_metadata ==8.7.* 6 | packaging ==25.0 7 | -------------------------------------------------------------------------------- /requirements/gha-schema.txt: -------------------------------------------------------------------------------- 1 | check-jsonschema ==0.33.* 2 | -------------------------------------------------------------------------------- /requirements/typing.txt: -------------------------------------------------------------------------------- 1 | mypy>=1.0.0 2 | 3 | types-setuptools 4 | -------------------------------------------------------------------------------- /scripts/adjust-torch-versions.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | # 4 | """Adjusting version across PTorch ecosystem.""" 5 | 6 | import logging 7 | import os 8 | import re 9 | import sys 10 | from typing import Optional 11 | 12 | 13 | def _determine_torchaudio(torch_version: str) -> str: 14 | """Determine the torchaudio version based on the torch version. 
15 | 16 | >>> _determine_torchaudio("1.9.0") 17 | '0.9.0' 18 | >>> _determine_torchaudio("2.4.1") 19 | '2.4.1' 20 | >>> _determine_torchaudio("1.8.2") 21 | '0.9.1' 22 | 23 | """ 24 | _version_exceptions = { 25 | "2.0.1": "2.0.2", 26 | "2.0.0": "2.0.1", 27 | "1.8.2": "0.9.1", 28 | } 29 | if torch_version in _version_exceptions: 30 | return _version_exceptions[torch_version] 31 | ver_major, ver_minor, ver_bugfix = map(int, torch_version.split(".")) 32 | ta_ver_array = [ver_major, ver_minor, ver_bugfix] 33 | if ver_major == 1: 34 | ta_ver_array[0] = 0 35 | ta_ver_array[2] = ver_bugfix 36 | return ".".join(map(str, ta_ver_array)) 37 | 38 | 39 | def _determine_torchtext(torch_version: str) -> str: 40 | """Determine the torchtext version based on the torch version. 41 | 42 | >>> _determine_torchtext("1.9.0") 43 | '0.10.0' 44 | >>> _determine_torchtext("2.4.1") 45 | '0.18.0' 46 | >>> _determine_torchtext("1.8.2") 47 | '0.9.1' 48 | 49 | """ 50 | _version_exceptions = { 51 | "2.0.1": "0.15.2", 52 | "2.0.0": "0.15.1", 53 | "1.8.2": "0.9.1", 54 | } 55 | if torch_version in _version_exceptions: 56 | return _version_exceptions[torch_version] 57 | ver_major, ver_minor, ver_bugfix = map(int, torch_version.split(".")) 58 | tt_ver_array = [0, 0, 0] 59 | if ver_major == 1: 60 | tt_ver_array[1] = ver_minor + 1 61 | tt_ver_array[2] = ver_bugfix 62 | elif ver_major == 2: 63 | if ver_minor >= 3: 64 | tt_ver_array[1] = 18 65 | else: 66 | tt_ver_array[1] = ver_minor + 15 67 | tt_ver_array[2] = ver_bugfix 68 | else: 69 | raise ValueError(f"Invalid torch version: {torch_version}") 70 | return ".".join(map(str, tt_ver_array)) 71 | 72 | 73 | def _determine_torchvision(torch_version: str) -> str: 74 | """Determine the torchvision version based on the torch version. 
75 | 76 | >>> _determine_torchvision("1.9.0") 77 | '0.10.0' 78 | >>> _determine_torchvision("2.4.1") 79 | '0.19.1' 80 | >>> _determine_torchvision("2.0.1") 81 | '0.15.2' 82 | 83 | """ 84 | _version_exceptions = { 85 | "2.0.1": "0.15.2", 86 | "2.0.0": "0.15.1", 87 | "1.10.2": "0.11.3", 88 | "1.10.1": "0.11.2", 89 | "1.10.0": "0.11.1", 90 | "1.8.2": "0.9.1", 91 | } 92 | if torch_version in _version_exceptions: 93 | return _version_exceptions[torch_version] 94 | ver_major, ver_minor, ver_bugfix = map(int, torch_version.split(".")) 95 | tv_ver_array = [0, 0, 0] 96 | if ver_major == 1: 97 | tv_ver_array[1] = ver_minor + 1 98 | elif ver_major == 2: 99 | tv_ver_array[1] = ver_minor + 15 100 | else: 101 | raise ValueError(f"Invalid torch version: {torch_version}") 102 | tv_ver_array[2] = ver_bugfix 103 | return ".".join(map(str, tv_ver_array)) 104 | 105 | 106 | def find_latest(ver: str) -> dict[str, str]: 107 | """Find the latest version. 108 | 109 | >>> from pprint import pprint 110 | >>> pprint(find_latest("2.4.1")) 111 | {'torch': '2.4.1', 112 | 'torchaudio': '2.4.1', 113 | 'torchtext': '0.18.0', 114 | 'torchvision': '0.19.1'} 115 | >>> pprint(find_latest("2.1")) 116 | {'torch': '2.1.0', 117 | 'torchaudio': '2.1.0', 118 | 'torchtext': '0.16.0', 119 | 'torchvision': '0.16.0'} 120 | 121 | """ 122 | # drop all except semantic version 123 | ver = re.search(r"([\.\d]+)", ver).groups()[0] 124 | # in case there remaining dot at the end - e.g "1.9.0.dev20210504" 125 | ver = ver[:-1] if ver[-1] == "." 
else ver 126 | if not re.match(r"\d+\.\d+\.\d+", ver): 127 | ver += ".0" # add missing bugfix 128 | logging.debug(f"finding ecosystem versions for: {ver}") 129 | 130 | # find first match 131 | return { 132 | "torch": ver, 133 | "torchvision": _determine_torchvision(ver), 134 | "torchtext": _determine_torchtext(ver), 135 | "torchaudio": _determine_torchaudio(ver), 136 | } 137 | 138 | 139 | def adjust(requires: list[str], pytorch_version: Optional[str] = None) -> list[str]: 140 | """Adjust the versions to be paired within pytorch ecosystem. 141 | 142 | >>> from pprint import pprint 143 | >>> pprint(adjust(["torch>=1.9.0", "torchvision>=0.10.0", "torchtext>=0.10.0", "torchaudio>=0.9.0"], "2.1.0")) 144 | ['torch==2.1.0', 145 | 'torchvision==0.16.0', 146 | 'torchtext==0.16.0', 147 | 'torchaudio==2.1.0'] 148 | 149 | """ 150 | if not pytorch_version: 151 | import torch 152 | 153 | pytorch_version = torch.__version__ 154 | if not pytorch_version: 155 | raise ValueError(f"invalid torch: {pytorch_version}") 156 | 157 | requires_ = [] 158 | options = find_latest(pytorch_version) 159 | logging.debug(f"determined ecosystem alignment: {options}") 160 | for req in requires: 161 | req_split = req.strip().split("#", maxsplit=1) 162 | # anything before fst # shall be requirements 163 | req = req_split[0].strip() 164 | # anything after # in the line is comment 165 | comment = "" if len(req_split) < 2 else " #" + req_split[1] 166 | if not req: 167 | # if only comment make it short 168 | requires_.append(comment.strip()) 169 | continue 170 | for lib, version in options.items(): 171 | replace = f"{lib}=={version}" if version else "" 172 | req = re.sub(rf"\b{lib}(?![-_\w]).*", replace, req) 173 | requires_.append(req + comment.rstrip()) 174 | 175 | return requires_ 176 | 177 | 178 | def _offset_print(reqs: list[str], offset: str = "\t|\t") -> str: 179 | """Adding offset to each line for the printing requirements.""" 180 | reqs = [offset + r for r in reqs] 181 | return 
os.linesep.join(reqs) 182 | 183 | 184 | def main(requirements_path: str, torch_version: Optional[str] = None) -> None: 185 | """The main entry point with mapping to the CLI for positional arguments only.""" 186 | # rU - universal line ending - https://stackoverflow.com/a/2717154/4521646 187 | with open(requirements_path, encoding="utf8") as fopen: 188 | requirements = fopen.readlines() 189 | requirements = adjust(requirements, torch_version) 190 | logging.info( 191 | f"requirements_path='{requirements_path}' with arg torch_version='{torch_version}' >>\n" 192 | f"{_offset_print(requirements)}" 193 | ) 194 | with open(requirements_path, "w", encoding="utf8") as fopen: 195 | fopen.writelines([r + os.linesep for r in requirements]) 196 | 197 | 198 | if __name__ == "__main__": 199 | logging.basicConfig(level=logging.INFO) 200 | try: 201 | from jsonargparse import auto_cli, set_parsing_settings 202 | 203 | set_parsing_settings(parse_optionals_as_positionals=True) 204 | auto_cli(main) 205 | except (ModuleNotFoundError, ImportError): 206 | main(*sys.argv[1:]) 207 | -------------------------------------------------------------------------------- /scripts/inject-selector-script.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | # 4 | """Simple script to inject a custom JS script into all HTML pages in given folder. 
5 | 6 | Sample usage: 7 | $ python scripts/inject-selector-script.py "/path/to/folder" torchmetrics 8 | 9 | """ 10 | 11 | import logging 12 | import os 13 | import sys 14 | 15 | 16 | def inject_selector_script_into_html_file(file_path: str, script_url: str) -> None: 17 | """Inject a custom JS script into the given HTML file.""" 18 | with open(file_path) as fopen: 19 | html_content = fopen.read() 20 | html_content = html_content.replace( 21 | "", 22 | f'{os.linesep}', 23 | ) 24 | with open(file_path, "w") as fopen: 25 | fopen.write(html_content) 26 | 27 | 28 | def main(folder: str, selector_name: str) -> None: 29 | """Inject a custom JS script into all HTML files in the given folder.""" 30 | # Sample: https://lightning.ai/docs/torchmetrics/version-selector.js 31 | script_url = f"https://lightning.ai/docs/{selector_name}/version-selector.js" 32 | html_files = [ 33 | os.path.join(root, file) for root, _, files in os.walk(folder) for file in files if file.endswith(".html") 34 | ] 35 | for file_path in html_files: 36 | inject_selector_script_into_html_file(file_path, script_url) 37 | 38 | 39 | if __name__ == "__main__": 40 | logging.basicConfig(level=logging.INFO) 41 | try: 42 | from jsonargparse import auto_cli, set_parsing_settings 43 | 44 | set_parsing_settings(parse_optionals_as_positionals=True) 45 | auto_cli(main) 46 | except (ModuleNotFoundError, ImportError): 47 | main(*sys.argv[1:]) 48 | -------------------------------------------------------------------------------- /scripts/run_standalone_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright The Lightning AI team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # THIS FILE ASSUMES IT IS RUN INSIDE THE tests DIRECTORY. 17 | 18 | # Batch size for testing: Determines how many standalone test invocations run in parallel 19 | # It can be set through the env variable NUM_PARALLEL_TESTS and defaults to 5 if not set 20 | test_batch_size="${NUM_PARALLEL_TESTS:-5}" 21 | 22 | # Source directory for coverage runs can be set with the env variable COVERAGE_SOURCE; coverage is disabled if unset. 23 | codecov_source="${COVERAGE_SOURCE}" 24 | 25 | # The test directory is passed as the first argument to the script 26 | test_dir=$1 # parse the first argument 27 | 28 | # There is also a per-test timeout (forwarded to pytest as --timeout). 29 | # It can be set through the env variable TEST_TIMEOUT and defaults to 1200 seconds. 30 | test_timeout="${TEST_TIMEOUT:-1200}" 31 | 32 | # Temporary file to store the collected tests 33 | COLLECTED_TESTS_FILE="collected_tests.txt" 34 | 35 | ls -lh . # show the contents of the directory 36 | 37 | # If codecov_source is set, prepend the coverage command 38 | if [ -n "$codecov_source" ]; then 39 | cli_coverage="-m coverage run --source ${codecov_source} --append" 40 | else # If not, just keep it empty 41 | cli_coverage="" 42 | fi 43 | # Append the common pytest arguments 44 | cli_pytest="-m pytest --no-header -v -s --color=yes --timeout=${test_timeout}" 45 | 46 | # Python arguments for running the tests and optional coverage 47 | printf "\e[35mUsing defaults: ${cli_coverage} ${cli_pytest}\e[0m\n" 48 | 49 | # Collect test IDs (one per parametrization) to run each separately; the "test_" filter below drops non-test lines.
50 | # note: if there's a syntax error, this will fail with some garbled output 51 | python -um pytest ${test_dir} -q --collect-only --pythonwarnings ignore 2>&1 > $COLLECTED_TESTS_FILE 52 | # Early terminate if collection failed (e.g. syntax error) 53 | if [[ $? != 0 ]]; then 54 | cat $COLLECTED_TESTS_FILE 55 | printf "ERROR: test collection failed!\n" 56 | exit 1 57 | fi 58 | 59 | # Initialize empty array 60 | tests=() 61 | 62 | # Read from file line by line 63 | while IFS= read -r line; do 64 | # Only keep lines containing "test_" 65 | if [[ $line == *"test_"* ]]; then 66 | # Extract part after test_dir/ 67 | pruned_line="${line#*${test_dir}/}" 68 | tests+=("${test_dir}/$pruned_line") 69 | fi 70 | done < $COLLECTED_TESTS_FILE 71 | 72 | # Count tests 73 | test_count=${#tests[@]} 74 | 75 | # Display results 76 | printf "\e[34m================================================================================\e[0m\n" 77 | printf "\e[34mCOLLECTED $test_count TESTS:\e[0m\n" 78 | printf "\e[34m~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\e[0m\n" 79 | printf "\e[34m%s\e[0m\n" "${tests[@]}" 80 | printf "\e[34m================================================================================\e[0m\n" 81 | 82 | # if test count is one print warning 83 | if [[ $test_count -eq 1 ]]; then 84 | printf "\e[33mWARNING: only one test found!\e[0m\n" 85 | elif [ $test_count -eq 0 ]; then 86 | printf "\e[31mERROR: no tests found!\e[0m\n" 87 | exit 1 88 | fi 89 | 90 | if [ -n "$codecov_source" ]; then 91 | coverage combine 92 | fi 93 | 94 | status=0 # aggregated script status 95 | report=() # final report 96 | pids=() # array of PID for running tests 97 | test_ids=() # array of indexes of running tests 98 | failed_tests=() # array of failed tests 99 | printf "Running $test_count tests in batches of $test_batch_size:\n" 100 | for i in "${!tests[@]}"; do 101 | test=${tests[$i]} 102 | 103 | cli_test="python " 104 | if [ -n "$codecov_source" ]; then 105 | # 
append cli_coverage to the test command 106 | cli_test="${cli_test} ${cli_coverage} --data-file=run-${i}.coverage" 107 | fi 108 | # add the pytest cli to the test command 109 | cli_test="${cli_test} ${cli_pytest}" 110 | 111 | printf "\e[95m* Running test $((i+1))/$test_count: $cli_test $test\e[0m\n" 112 | 113 | # execute the test in the background 114 | # redirect to a log file that buffers test output. since the tests will run in the background, 115 | # we cannot let them output to std{out,err} because the outputs would be garbled together 116 | ${cli_test} "$test" &> "parallel_test_output-$i.txt" & 117 | test_ids+=($i) # save the test's id in an array with running tests 118 | pids+=($!) # save the PID in an array with running tests 119 | 120 | # if we reached the batch size, wait for all tests to finish 121 | if (( (($i + 1) % $test_batch_size == 0) || $i == $test_count-1 )); then 122 | printf "Waiting for batch to finish: $(IFS=' '; echo "${pids[@]}")\n" 123 | # wait for running tests 124 | for j in "${!test_ids[@]}"; do 125 | i=${test_ids[$j]} # restore the global test's id 126 | pid=${pids[$j]} # restore the particular PID 127 | test=${tests[$i]} # restore the test name 128 | printf "\e[33m? Waiting for $test @ parallel_test_output-$i.txt (PID: $pid)\e[0m\n" 129 | wait -n $pid 130 | # get the exit status of the test 131 | test_status=$? 
132 | # add row to the final report 133 | report+=("Ran $test >> exit:$test_status") 134 | if [[ $test_status != 0 ]]; then 135 | # add the test to the failed tests array 136 | failed_tests+=($i) 137 | # Process exited with a non-zero exit status 138 | status=$test_status 139 | fi 140 | done 141 | printf "Starting over with a new batch...\n" 142 | test_ids=() # reset the test's id array 143 | pids=() # reset the PID array 144 | fi 145 | done 146 | 147 | # print test report with exit code for each test 148 | printf "\e[35m================================================================================\e[0m\n" 149 | for line in "${report[@]}"; do 150 | if [[ "$line" == *"exit:0"* ]]; then 151 | printf "\e[32m%s\e[0m\n" "$line" # Green for lines containing exit:0 152 | else 153 | printf "\e[31m%s\e[0m\n" "$line" # Red for all other lines 154 | fi 155 | done 156 | printf "\e[35m================================================================================\e[0m\n" 157 | 158 | # print failed tests from duped logs 159 | if [[ ${#failed_tests[@]} -gt 0 ]]; then 160 | printf "\e[34mFAILED TESTS:\e[0m\n" 161 | for i in "${failed_tests[@]}"; do 162 | printf "\e[34m================================================================================\e[0m\n" 163 | printf "\e[34m=== ${tests[$i]} ===\e[0m\n" 164 | printf "\e[34m~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\e[0m\n\n" 165 | # show the output of the failed test 166 | cat "parallel_test_output-$i.txt" 167 | printf "\e[34m~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\e[0m\n" 168 | printf "\e[34m================================================================================\e[0m\n" 169 | printf '\n\n\n' 170 | done 171 | else 172 | printf "\e[32mAll tests passed!\e[0m\n" 173 | fi 174 | 175 | # exit with the worse test result 176 | exit $status 177 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import glob 3 | import os 4 | from importlib.util import module_from_spec, spec_from_file_location 5 | 6 | from pkg_resources import parse_requirements 7 | from setuptools import find_packages, setup 8 | 9 | _PATH_ROOT = os.path.realpath(os.path.dirname(__file__)) 10 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src") 11 | _PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements") 12 | 13 | 14 | def _load_py_module(fname: str, pkg: str = "lightning_utilities"): 15 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname)) 16 | py = module_from_spec(spec) 17 | spec.loader.exec_module(py) 18 | return py 19 | 20 | 21 | about = _load_py_module("__about__.py") 22 | 23 | # load basic requirements 24 | with open(os.path.join(_PATH_REQUIRE, "core.txt")) as fp: 25 | requirements = list(map(str, parse_requirements(fp.readlines()))) 26 | 27 | 28 | # make extras as automated loading 29 | def _requirement_extras(path_req: str = _PATH_REQUIRE) -> dict: 30 | extras = {} 31 | for fpath in glob.glob(os.path.join(path_req, "*.txt")): 32 | fname = os.path.basename(fpath) 33 | if fname.startswith(("_", "gha-")): 34 | continue 35 | if fname in ("core.txt",): 36 | continue 37 | name, _ = os.path.splitext(fname) 38 | with open(fpath) as fp: 39 | reqs = parse_requirements(fp.readlines()) 40 | extras[name] = list(map(str, reqs)) 41 | return extras 42 | 43 | 44 | # loading readme as description 45 | with open(os.path.join(_PATH_ROOT, "README.md")) as fp: 46 | readme = fp.read() 47 | 48 | setup( 49 | name="lightning-utilities", 50 | version=about.__version__, 51 | description=about.__docs__, 52 | author=about.__author__, 53 | author_email=about.__author_email__, 54 | url=about.__homepage__, 55 | download_url="https://github.com/Lightning-AI/utilities", 56 | license=about.__license__, # Should be a license identifier like "MIT" 57 | 
license_files=["LICENSE"], # Path to your license file 58 | packages=find_packages(where="src"), 59 | package_dir={"": "src"}, 60 | long_description=readme, 61 | long_description_content_type="text/markdown", 62 | include_package_data=True, 63 | zip_safe=False, 64 | keywords=["Utilities", "DevOps", "CI/CD"], 65 | python_requires=">=3.9", 66 | setup_requires=[], 67 | install_requires=requirements, 68 | extras_require=_requirement_extras(), 69 | project_urls={ 70 | "Bug Tracker": "https://github.com/Lightning-AI/utilities/issues", 71 | "Documentation": "https://dev-toolbox.rtfd.io/en/latest/", # TODO: Update domain 72 | "Source Code": "https://github.com/Lightning-AI/utilities", 73 | }, 74 | classifiers=[ 75 | "Environment :: Console", 76 | "Natural Language :: English", 77 | # How mature is this project? Common values are 78 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable 79 | "Development Status :: 3 - Alpha", 80 | # Indicate who your project is intended for 81 | "Intended Audience :: Developers", 82 | # No "License ::" classifier: the license is declared as an SPDX id via license=about.__license__ 83 | # (PEP 639 deprecates license classifiers in favor of license expressions) 84 | "Operating System :: OS Independent", 85 | "Programming Language :: Python :: 3", # listed versions must agree with python_requires=">=3.9" 86 | "Programming Language :: Python :: 3.9", 87 | "Programming Language :: Python :: 3.10", 88 | "Programming Language :: Python :: 3.11", 89 | "Programming Language :: Python :: 3.12", 90 | "Programming Language :: Python :: 3.13", 91 | ], 92 | ) 93 | -------------------------------------------------------------------------------- /src/lightning_utilities/__about__.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | __version__ = "0.15.0dev" 4 | __author__ = "Lightning AI et al." 5 | __author_email__ = "pytorch@lightning.ai" 6 | __license__ = "Apache-2.0" 7 | __copyright__ = f"Copyright (c) 2022-{time.strftime('%Y')}, {__author__}."
8 | __homepage__ = "https://github.com/Lightning-AI/utilities" 9 | __docs__ = "Lightning toolbox for across the our ecosystem." 10 | __long_doc__ = """ 11 | This package allows for sharing GitHub workflows, CI/CD assistance actions, and Python utilities across the Lightning 12 | ecosystem - projects. 13 | """ 14 | 15 | __all__ = [ 16 | "__author__", 17 | "__author_email__", 18 | "__copyright__", 19 | "__docs__", 20 | "__homepage__", 21 | "__license__", 22 | "__version__", 23 | ] 24 | -------------------------------------------------------------------------------- /src/lightning_utilities/__init__.py: -------------------------------------------------------------------------------- 1 | """Root package info.""" 2 | 3 | import os 4 | 5 | from lightning_utilities.__about__ import * # noqa: F403 6 | from lightning_utilities.core.apply_func import apply_to_collection 7 | from lightning_utilities.core.enums import StrEnum 8 | from lightning_utilities.core.imports import compare_version, module_available 9 | from lightning_utilities.core.overrides import is_overridden 10 | from lightning_utilities.core.rank_zero import WarningCache 11 | 12 | _PACKAGE_ROOT = os.path.dirname(__file__) 13 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) 14 | 15 | 16 | __all__ = [ 17 | "StrEnum", 18 | "WarningCache", 19 | "apply_to_collection", 20 | "compare_version", 21 | "is_overridden", 22 | "module_available", 23 | ] 24 | -------------------------------------------------------------------------------- /src/lightning_utilities/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """CLI root.""" 2 | -------------------------------------------------------------------------------- /src/lightning_utilities/cli/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 
2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # http://www.apache.org/licenses/LICENSE-2.0 4 | # 5 | 6 | import lightning_utilities 7 | from lightning_utilities.cli.dependencies import ( 8 | prune_packages_in_requirements, 9 | replace_oldest_version, 10 | replace_package_in_requirements, 11 | ) 12 | 13 | 14 | def _get_version() -> None: 15 | """Prints the version of the lightning_utilities package.""" 16 | print(lightning_utilities.__version__) 17 | 18 | 19 | def main() -> None: 20 | """CLI entry point.""" 21 | from jsonargparse import auto_cli, set_parsing_settings 22 | 23 | set_parsing_settings(parse_optionals_as_positionals=True) 24 | auto_cli({ 25 | "requirements": { 26 | "_help": "Manage requirements files.", 27 | "prune-pkgs": prune_packages_in_requirements, 28 | "set-oldest": replace_oldest_version, 29 | "replace-pkg": replace_package_in_requirements, 30 | }, 31 | "version": _get_version, 32 | }) 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /src/lightning_utilities/cli/dependencies.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 
# Licensed under the Apache License, Version 2.0 (the "License");
#     http://www.apache.org/licenses/LICENSE-2.0
#
import glob
import os.path
import re
from collections.abc import Sequence
from pprint import pprint
from typing import Union

REQUIREMENT_ROOT = "requirements.txt"
# With ``recursive=True`` the ``**`` component also matches zero directories, so the second glob already
# returns every file found by the first one; ``dict.fromkeys`` drops those duplicates (keeping discovery
# order) so that each requirement file is processed only once.
REQUIREMENT_FILES_ALL: list = list(
    dict.fromkeys(
        glob.glob(os.path.join("requirements", "*.txt"))
        + glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
    )
)
if os.path.isfile(REQUIREMENT_ROOT):
    REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]

# Lookahead marking the end of a distribution name at the start of a requirement line: whitespace (including
# the trailing newline), a version operator (``<``, ``=``, ``>``, ``!``, ``~``), extras (``[``), an
# environment marker (``;``), a direct URL (``@``), a comment (``#``) or the end of the line.
_PKG_NAME_END = r"(?=[\s<=>!~;\[@#]|$)"


def _package_pattern(package: str) -> "re.Pattern[str]":
    """Compile a regex matching ``package`` as the complete distribution name at the start of a line."""
    return re.compile(r"^" + re.escape(package) + _PKG_NAME_END)


def prune_packages_in_requirements(
    packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Remove some packages from given requirement files.

    Args:
        packages: a package name or a sequence of package names to drop
        req_files: a requirement file path or a sequence of paths; defaults to all discovered requirement files

    """
    if isinstance(packages, str):
        packages = [packages]
    if isinstance(req_files, str):
        req_files = [req_files]
    for req in req_files:
        _prune_packages(req, packages)


def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
    """Remove the lines declaring any of ``packages`` from a single requirement file.

    Only exact package names are matched, so pruning ``torch`` keeps ``torchvision``.
    The remaining lines are pretty-printed to stdout before the file is rewritten.

    """
    with open(req_file) as fp:
        lines = fp.readlines()

    if isinstance(packages, str):
        packages = [packages]
    patterns = [_package_pattern(pkg) for pkg in packages]
    # a plain prefix check (``str.startswith``) would also drop packages that merely share a prefix
    lines = [ln for ln in lines if not any(pattern.match(ln) for pattern in patterns)]
    pprint(lines)

    with open(req_file, "w") as fp:
        fp.writelines(lines)


def _pin_min_versions(line: str) -> str:
    """Turn ``>=`` lower bounds into ``==`` pins in the specifier part of one requirement line.

    Environment markers (after ``;``) and comments (after ``#``) are left untouched, so a marker
    such as ``python_version >= "3.9"`` keeps its meaning.

    >>> _pin_min_versions('numpy >=1.2, <2.0; python_version >= "3.9"  # needs >=1.2')
    'numpy ==1.2, <2.0; python_version >= "3.9"  # needs >=1.2'

    """
    cut = len(line)
    for separator in (";", "#"):
        pos = line.find(separator)
        if pos != -1:
            cut = min(cut, pos)
    return line[:cut].replace(">=", "==") + line[cut:]


def _replace_min(fname: str) -> None:
    """Pin every ``>=`` lower bound in the requirement file ``fname`` to ``==``, in place."""
    with open(fname) as fopen:
        lines = fopen.readlines()
    with open(fname, "w") as fw:
        fw.writelines(_pin_min_versions(ln) for ln in lines)


def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
    """Replace the min package version by fixed one.

    Args:
        req_files: a requirement file path or a sequence of paths; defaults to all discovered requirement files

    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        _replace_min(fname)


def _replace_package_name(requirements: list[str], old_package: str, new_package: str) -> list[str]:
    """Replace one package by another with same version in given requirement file.

    The list is modified in place and also returned.

    >>> _replace_package_name(["torch>=1.0 # comment", "torchvision>=0.2", "torchtext <0.3"], "torch", "pytorch")
    ['pytorch>=1.0 # comment', 'torchvision>=0.2', 'torchtext <0.3']
    >>> _replace_package_name(["torch[extra]~=1.0"], "torch", "pytorch")
    ['pytorch[extra]~=1.0']

    """
    pattern = _package_pattern(old_package)
    for i, req in enumerate(requirements):
        # a callable replacement inserts ``new_package`` literally (no backslash/group expansion)
        requirements[i] = pattern.sub(lambda _match: new_package, req)
    return requirements


def replace_package_in_requirements(
    old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Replace one package by another with same version in given requirement files.

    Args:
        old_package: package name to be replaced
        new_package: package name to put in its place
        req_files: a requirement file path or a sequence of paths; defaults to all discovered requirement files

    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        with open(fname) as fopen:
            reqs = fopen.readlines()
        reqs = _replace_package_name(reqs, old_package, new_package)
        with open(fname, "w") as fw:
            fw.writelines(reqs)


# --------------------------------------------------------------------------------
# /src/lightning_utilities/core/__init__.py:
# --------------------------------------------------------------------------------
"""Core utilities."""

from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.enums import StrEnum
from lightning_utilities.core.imports import compare_version, module_available
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import WarningCache

__all__ = [
    "StrEnum",
    "WarningCache",
    "apply_to_collection",
    "compare_version",
    "is_overridden",
    "module_available",
]

# --------------------------------------------------------------------------------
# /src/lightning_utilities/core/enums.py:
# --------------------------------------------------------------------------------
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
#     http://www.apache.org/licenses/LICENSE-2.0
#
import warnings
from enum import Enum
from typing import Optional

from typing_extensions import Literal


class StrEnum(str, Enum):
    """Type of any enumerator with allowed comparison to string invariant to cases.

    >>> class MySE(StrEnum):
    ...     t1 = "T-1"
    ...     t2 = "T-2"
    >>> MySE("T-1") == MySE.t1
    True
    >>> MySE.from_str("t-2", source="value") == MySE.t2
    True
    >>> MySE.from_str("t-2", source="value")
    <MySE.t2: 'T-2'>
    >>> MySE.from_str("t-3", source="any")
    Traceback (most recent call last):
      ...
    ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.

    """

    @classmethod
    def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum":
        """Create ``StrEnum`` from a string matching the key or value.

        Args:
            value: matching string
            source: compare with:

                - ``"key"``: validates only from the enum keys, typical alphanumeric with "_"
                - ``"value"``: validates only from the values, could be any string
                - ``"any"``: validates with any key or value, but key has priority

        Raises:
            ValueError:
                if requested string does not match any option based on selected source.

        """
        if source in ("key", "any"):
            for enum_key in cls.__members__:
                if enum_key.lower() == value.lower():
                    return cls[enum_key]
        if source in ("value", "any"):
            for enum_key, enum_val in cls.__members__.items():
                # ``==`` goes through ``StrEnum.__eq__``, so value matching is case-insensitive as well
                if enum_val == value:
                    return cls[enum_key]
        raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.")

    @classmethod
    def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
        """Try to create enum and if it does not match any, return `None`.

        A ``UserWarning`` listing the allowed matches is emitted when no member matches.

        """
        try:
            return cls.from_str(value, source)
        except ValueError:
            # stacklevel=2 attributes the warning to the caller of ``try_from_str``
            warnings.warn(
                UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}."),
                stacklevel=2,
            )
        return None

    @classmethod
    def _allowed_matches(cls, source: str) -> list[str]:
        """List the member keys, the member values, or keys followed by values (for any other ``source``)."""
        keys = list(cls.__members__)
        vals = [member.value for member in cls.__members__.values()]
        if source == "key":
            return keys
        if source == "value":
            return vals
        return keys + vals

    def __eq__(self, other: object) -> bool:
        """Compare two instances."""
        if isinstance(other, Enum):
            other = other.value
        return self.value.lower() == str(other).lower()

    def __hash__(self) -> int:
        """Return unique hash."""
        # re-enable hashtable, so it can be used as a dict key or in a set
        # example: set(LightningEnum)
        return hash(self.value.lower())


# --------------------------------------------------------------------------------
# /src/lightning_utilities/core/inheritance.py:
# --------------------------------------------------------------------------------
# Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # http://www.apache.org/licenses/LICENSE-2.0 4 | # 5 | from collections.abc import Iterator 6 | 7 | 8 | def get_all_subclasses_iterator(cls: type) -> Iterator[type]: 9 | """Iterate over all subclasses.""" 10 | 11 | def recurse(cl: type) -> Iterator[type]: 12 | for subclass in cl.__subclasses__(): 13 | yield subclass 14 | yield from recurse(subclass) 15 | 16 | yield from recurse(cls) 17 | 18 | 19 | def get_all_subclasses(cls: type) -> set[type]: 20 | """List all subclasses of a class.""" 21 | return set(get_all_subclasses_iterator(cls)) 22 | -------------------------------------------------------------------------------- /src/lightning_utilities/core/overrides.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # http://www.apache.org/licenses/LICENSE-2.0 4 | # 5 | from functools import partial 6 | from unittest.mock import Mock 7 | 8 | 9 | def is_overridden(method_name: str, instance: object, parent: type[object]) -> bool: 10 | """Check if a method of a given object was overwritten.""" 11 | instance_attr = getattr(instance, method_name, None) 12 | if instance_attr is None: 13 | return False 14 | # `functools.wraps()` and `@contextmanager` support 15 | if hasattr(instance_attr, "__wrapped__"): 16 | instance_attr = instance_attr.__wrapped__ 17 | # `Mock(wraps=...)` support 18 | if isinstance(instance_attr, Mock): 19 | # access the wrapped function 20 | instance_attr = instance_attr._mock_wraps 21 | # `partial` support 22 | elif isinstance(instance_attr, partial): 23 | instance_attr = instance_attr.func 24 | if instance_attr is None: 25 | return False 26 | 27 | parent_attr = getattr(parent, method_name, None) 28 | if parent_attr is None: 29 | raise ValueError("The parent should define the method") 30 | # `@contextmanager` support 31 | if 
hasattr(parent_attr, "__wrapped__"): 32 | parent_attr = parent_attr.__wrapped__ 33 | 34 | return instance_attr.__code__ != parent_attr.__code__ 35 | -------------------------------------------------------------------------------- /src/lightning_utilities/core/rank_zero.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # http://www.apache.org/licenses/LICENSE-2.0 4 | # 5 | """Utilities that can be used for calling functions on a particular rank.""" 6 | 7 | import logging 8 | import warnings 9 | from functools import wraps 10 | from typing import Any, Callable, Optional, TypeVar, Union 11 | 12 | from typing_extensions import ParamSpec, overload 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | T = TypeVar("T") 17 | P = ParamSpec("P") 18 | 19 | 20 | @overload 21 | def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: ... 22 | 23 | 24 | @overload 25 | def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ... 26 | 27 | 28 | def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]: 29 | """Wrap a function to call internal function only in rank zero. 30 | 31 | Function that can be used as a decorator to enable a function/method being called only on global rank 0. 
32 | 33 | """ 34 | 35 | @wraps(fn) 36 | def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: 37 | rank = getattr(rank_zero_only, "rank", None) 38 | if rank is None: 39 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") 40 | if rank == 0: 41 | return fn(*args, **kwargs) 42 | return default 43 | 44 | return wrapped_fn 45 | 46 | 47 | def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: 48 | kwargs["stacklevel"] = stacklevel 49 | log.debug(*args, **kwargs) 50 | 51 | 52 | @rank_zero_only 53 | def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: 54 | """Emit debug-level messages only on global rank 0.""" 55 | _debug(*args, stacklevel=stacklevel, **kwargs) 56 | 57 | 58 | def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: 59 | kwargs["stacklevel"] = stacklevel 60 | log.info(*args, **kwargs) 61 | 62 | 63 | @rank_zero_only 64 | def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: 65 | """Emit info-level messages only on global rank 0.""" 66 | _info(*args, stacklevel=stacklevel, **kwargs) 67 | 68 | 69 | def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None: 70 | warnings.warn(message, stacklevel=stacklevel, **kwargs) 71 | 72 | 73 | @rank_zero_only 74 | def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None: 75 | """Emit warn-level messages only on global rank 0.""" 76 | _warn(message, stacklevel=stacklevel, **kwargs) 77 | 78 | 79 | rank_zero_deprecation_category = DeprecationWarning 80 | 81 | 82 | def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **kwargs: Any) -> None: 83 | """Emit a deprecation warning only on global rank 0.""" 84 | category = kwargs.pop("category", rank_zero_deprecation_category) 85 | rank_zero_warn(message, stacklevel=stacklevel, category=category, **kwargs) 86 | 87 | 88 | def rank_prefixed_message(message: str, rank: Optional[int]) -> str: 89 | 
"""Add a prefix with the rank to a message.""" 90 | if rank is not None: 91 | # specify the rank of the process being logged 92 | return f"[rank: {rank}] {message}" 93 | return message 94 | 95 | 96 | class WarningCache(set): 97 | """Cache for warnings.""" 98 | 99 | def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: 100 | """Trigger warning message.""" 101 | if message not in self: 102 | self.add(message) 103 | rank_zero_warn(message, stacklevel=stacklevel, **kwargs) 104 | 105 | def deprecation(self, message: str, stacklevel: int = 6, **kwargs: Any) -> None: 106 | """Trigger deprecation message.""" 107 | if message not in self: 108 | self.add(message) 109 | rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs) 110 | 111 | def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: 112 | """Trigger info message.""" 113 | if message not in self: 114 | self.add(message) 115 | rank_zero_info(message, stacklevel=stacklevel, **kwargs) 116 | -------------------------------------------------------------------------------- /src/lightning_utilities/docs/__init__.py: -------------------------------------------------------------------------------- 1 | """General tools for Docs.""" 2 | 3 | from lightning_utilities.docs.formatting import adjust_linked_external_docs 4 | from lightning_utilities.docs.retriever import fetch_external_assets 5 | 6 | __all__ = ["adjust_linked_external_docs", "fetch_external_assets"] 7 | -------------------------------------------------------------------------------- /src/lightning_utilities/docs/formatting.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | # 4 | import glob 5 | import importlib 6 | import inspect 7 | import logging 8 | import os 9 | import re 10 | import sys 11 | from collections.abc import Iterable 12 | from typing import Optional, 
Union 13 | 14 | 15 | def _transform_changelog(path_in: str, path_out: str) -> None: 16 | """Adjust changelog titles so not to be duplicated. 17 | 18 | Args: 19 | path_in: input MD file 20 | path_out: output also MD file 21 | 22 | """ 23 | with open(path_in) as fp: 24 | chlog_lines = fp.readlines() 25 | # enrich short subsub-titles to be unique 26 | chlog_ver = "" 27 | for i, ln in enumerate(chlog_lines): 28 | if ln.startswith("## "): 29 | chlog_ver = ln[2:].split("-")[0].strip() 30 | elif ln.startswith("### "): 31 | ln = ln.replace("###", f"### {chlog_ver} -") 32 | chlog_lines[i] = ln 33 | with open(path_out, "w") as fp: 34 | fp.writelines(chlog_lines) 35 | 36 | 37 | def _linkcode_resolve( 38 | domain: str, 39 | info: dict, 40 | github_user: str, 41 | github_repo: str, 42 | main_branch: str = "master", 43 | stable_branch: str = "release/stable", 44 | ) -> str: 45 | def find_source() -> tuple[str, int, int]: 46 | # try to find the file and line number, based on code from numpy: 47 | # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286 48 | obj = sys.modules[info["module"]] 49 | for part in info["fullname"].split("."): 50 | obj = getattr(obj, part) 51 | fname = str(inspect.getsourcefile(obj)) 52 | # https://github.com/rtfd/readthedocs.org/issues/5735 53 | if any(s in fname for s in ("readthedocs", "rtfd", "checkouts")): 54 | # /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/ 55 | # devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176 56 | path_top = os.path.abspath(os.path.join("..", "..", "..")) 57 | fname = str(os.path.relpath(fname, start=path_top)) 58 | else: 59 | # Local build, imitate master 60 | fname = f"{main_branch}/{os.path.relpath(fname, start=os.path.abspath('..'))}" 61 | source, line_start = inspect.getsourcelines(obj) 62 | return fname, line_start, line_start + len(source) - 1 63 | 64 | if domain != "py" or not info["module"]: 65 | return "" 66 | try: 67 | filename = "%s#L%d-L%d" % find_source() # 
noqa: UP031 68 | except Exception: 69 | filename = info["module"].replace(".", "/") + ".py" 70 | # import subprocess 71 | # tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE, 72 | # universal_newlines=True).communicate()[0][:-1] 73 | branch = filename.split("/")[0] 74 | # do mapping from latest tags to master 75 | branch = {"latest": main_branch, "stable": stable_branch}.get(branch, branch) 76 | filename = "/".join([branch] + filename.split("/")[1:]) 77 | return f"https://github.com/{github_user}/{github_repo}/blob/{filename}" 78 | 79 | 80 | def _load_pypi_versions(package_name: str) -> list[str]: 81 | """Load the versions of the package from PyPI. 82 | 83 | >>> _load_pypi_versions("numpy") # doctest: +ELLIPSIS 84 | ['0.9.6', '0.9.8', '1.0', ...] 85 | >>> _load_pypi_versions("scikit-learn") # doctest: +ELLIPSIS 86 | ['0.9', '0.10', '0.11', '0.12', ...] 87 | 88 | """ 89 | from distutils.version import LooseVersion 90 | 91 | import requests 92 | 93 | url = f"https://pypi.org/pypi/{package_name}/json" 94 | data = requests.get(url, timeout=10).json() 95 | versions = data["releases"].keys() 96 | # filter all version which include only numbers and dots 97 | versions = {k for k in versions if re.match(r"^\d+(\.\d+)*$", k)} 98 | return sorted(versions, key=LooseVersion) 99 | 100 | 101 | def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str: 102 | """Adjust the linked external docs to be local. 
103 | 104 | Args: 105 | link: the source link to be replaced 106 | pkg_ver: the target link to be replaced, if ``{package.version}`` is included it will be replaced accordingly 107 | version_digits: for semantic versioning, how many digits to be considered 108 | 109 | """ 110 | pkg_att = pkg_ver.split(".") 111 | try: 112 | ver = _load_pypi_versions(pkg_att[0])[-1] 113 | except Exception: 114 | # load the package with all additional sub-modules 115 | module = importlib.import_module(".".join(pkg_att[:-1])) 116 | # load the attribute 117 | ver = getattr(module, pkg_att[0]) 118 | # drop any additional context after `+` 119 | ver = ver.split("+")[0] 120 | # crop the version to the number of digits 121 | ver = ".".join(ver.split(".")[:version_digits]) 122 | # replace the version 123 | return link.replace(f"{{{pkg_ver}}}", ver) 124 | 125 | 126 | def adjust_linked_external_docs( 127 | source_link: str, 128 | target_link: str, 129 | browse_folder: Union[str, Iterable[str]], 130 | file_extensions: Iterable[str] = (".rst", ".py"), 131 | version_digits: int = 2, 132 | ) -> None: 133 | r"""Adjust the linked external docs to be local. 134 | 135 | Args: 136 | source_link: the link to be replaced 137 | target_link: the link to be replaced, if ``{package.version}`` is included it will be replaced accordingly 138 | browse_folder: the location of the browsable folder 139 | file_extensions: what kind of files shall be scanned 140 | version_digits: for semantic versioning, how many digits to be considered 141 | 142 | Examples: 143 | >>> adjust_linked_external_docs( 144 | ... "https://numpy.org/doc/stable/", 145 | ... "https://numpy.org/doc/{numpy.__version__}/", 146 | ... "docs/source", 147 | ... 
) 148 | 149 | """ 150 | list_files = [] 151 | if isinstance(browse_folder, str): 152 | browse_folder = [browse_folder] 153 | for folder in browse_folder: 154 | for ext in file_extensions: 155 | list_files += glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True) 156 | if not list_files: 157 | logging.warning(f'No files were listed in folder "{browse_folder}" and pattern "{file_extensions}"') 158 | return 159 | 160 | # find the expression for package version in {} brackets if any, use re to find it 161 | pkg_ver_all = re.findall(r"{(.+)}", target_link) 162 | for pkg_ver in pkg_ver_all: 163 | target_link = _update_link_based_imported_package(target_link, pkg_ver, version_digits) 164 | 165 | # replace the source link with target link 166 | for fpath in set(list_files): 167 | with open(fpath, encoding="UTF-8") as fopen: 168 | lines = fopen.readlines() 169 | found, skip = False, False 170 | for i, ln in enumerate(lines): 171 | # prevent the replacement its own function calls 172 | if f"{adjust_linked_external_docs.__name__}(" in ln: 173 | skip = True 174 | if not skip and source_link in ln: 175 | # replace the link if any found 176 | lines[i] = ln.replace(source_link, target_link) 177 | # record the found link for later write file 178 | found = True 179 | if skip and ")" in ln: 180 | skip = False 181 | if not found: 182 | continue 183 | logging.debug(f'links adjusting in {fpath}: "{source_link}" -> "{target_link}"') 184 | with open(fpath, "w", encoding="UTF-8") as fw: 185 | fw.writelines(lines) 186 | -------------------------------------------------------------------------------- /src/lightning_utilities/docs/retriever.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | # 4 | import glob 5 | import logging 6 | import os 7 | import re 8 | 9 | import requests 10 | 11 | 12 | def _download_file(file_url: str, 
folder: str) -> str: 13 | """Download a file from URL to a particular folder.""" 14 | fname = os.path.basename(file_url) 15 | file_path = os.path.join(folder, fname) 16 | if os.path.isfile(file_path): 17 | logging.warning(f'given file "{file_path}" already exists and will be overwritten with {file_url}') 18 | # see: https://stackoverflow.com/a/34957875 19 | rq = requests.get(file_url, timeout=10) 20 | with open(file_path, "wb") as outfile: 21 | outfile.write(rq.content) 22 | return fname 23 | 24 | 25 | def _search_all_occurrences(list_files: list[str], pattern: str) -> list[str]: 26 | """Search for all occurrences of specific pattern in a collection of files. 27 | 28 | Args: 29 | list_files: list of files to be scanned 30 | pattern: pattern for search, reg. expression 31 | 32 | """ 33 | collected = [] 34 | for file_path in list_files: 35 | with open(file_path, encoding="UTF-8") as fopem: 36 | body = fopem.read() 37 | found = re.findall(pattern, body) 38 | collected += found 39 | return collected 40 | 41 | 42 | def _replace_remote_with_local(file_path: str, docs_folder: str, pairs_url_path: list[tuple[str, str]]) -> None: 43 | """Replace all URL with local files in a given file. 
44 | 45 | Args: 46 | file_path: file for replacement 47 | docs_folder: the location of docs related to the project root 48 | pairs_url_path: pairs of URL and local file path to be swapped 49 | 50 | """ 51 | # drop the default/global path to the docs 52 | relt_path = os.path.dirname(file_path).replace(docs_folder, "") 53 | # filter the path starting with / as not empty folder names 54 | depth = len([p for p in relt_path.split(os.path.sep) if p]) 55 | with open(file_path, encoding="UTF-8") as fopen: 56 | body = fopen.read() 57 | for url, fpath in pairs_url_path: 58 | if depth: 59 | path_up = [".."] * depth 60 | fpath = os.path.join(*path_up, fpath) 61 | body = body.replace(url, fpath) 62 | with open(file_path, "w", encoding="UTF-8") as fw: 63 | fw.write(body) 64 | 65 | 66 | def fetch_external_assets( 67 | docs_folder: str = "docs/source", 68 | assets_folder: str = "fetched-s3-assets", 69 | file_pattern: str = "*.rst", 70 | retrieve_pattern: str = r"https?://[-a-zA-Z0-9_]+\.s3\.[-a-zA-Z0-9()_\\+.\\/=]+", 71 | ) -> None: 72 | """Search all URL in docs, download these files locally and replace online with local version. 73 | 74 | Args: 75 | docs_folder: the location of docs related to the project root 76 | assets_folder: a folder inside ``docs_folder`` to be created and saving online assets 77 | file_pattern: what kind of files shall be scanned 78 | retrieve_pattern: pattern for reg. 
expression to search URL/S3 resources 79 | 80 | """ 81 | list_files = glob.glob(os.path.join(docs_folder, "**", file_pattern), recursive=True) 82 | if not list_files: 83 | logging.warning(f'no files were listed in folder "{docs_folder}" and pattern "{file_pattern}"') 84 | return 85 | 86 | urls = _search_all_occurrences(list_files, pattern=retrieve_pattern) 87 | if not urls: 88 | logging.info(f"no resources/assets were match in {docs_folder} for {retrieve_pattern}") 89 | return 90 | target_folder = os.path.join(docs_folder, assets_folder) 91 | os.makedirs(target_folder, exist_ok=True) 92 | pairs_url_file = [] 93 | for i, url in enumerate(set(urls)): 94 | logging.info(f" >> downloading ({i}/{len(urls)}): {url}") 95 | fname = _download_file(url, target_folder) 96 | pairs_url_file.append((url, os.path.join(assets_folder, fname))) 97 | 98 | for fpath in list_files: 99 | _replace_remote_with_local(fpath, docs_folder, pairs_url_file) 100 | -------------------------------------------------------------------------------- /src/lightning_utilities/install/__init__.py: -------------------------------------------------------------------------------- 1 | """Generic Installation tools.""" 2 | 3 | from lightning_utilities.install.requirements import Requirement, load_requirements 4 | 5 | __all__ = ["Requirement", "load_requirements"] 6 | -------------------------------------------------------------------------------- /src/lightning_utilities/install/requirements.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | # 4 | import re 5 | from collections.abc import Iterable, Iterator 6 | from distutils.version import LooseVersion 7 | from pathlib import Path 8 | from typing import Any, Optional, Union 9 | 10 | from pkg_resources import Requirement, yield_lines # type: ignore[import-untyped] 11 | 12 | 13 | class 
_RequirementWithComment(Requirement): 14 | strict_string = "# strict" 15 | 16 | def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: 17 | super().__init__(*args, **kwargs) 18 | self.comment = comment 19 | if not (pip_argument is None or pip_argument): # sanity check that it's not an empty str 20 | raise RuntimeError(f"wrong pip argument: {pip_argument}") 21 | self.pip_argument = pip_argument 22 | self.strict = self.strict_string in comment.lower() 23 | 24 | def adjust(self, unfreeze: str) -> str: 25 | """Remove version restrictions unless they are strict. 26 | 27 | >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none") 28 | 'arrow<=1.2.2,>=1.2.0' 29 | >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none") 30 | 'arrow<=1.2.2,>=1.2.0 # strict' 31 | >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all") 32 | 'arrow>=1.2.0' 33 | >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all") 34 | 'arrow<=1.2.2,>=1.2.0 # strict' 35 | >>> _RequirementWithComment("arrow").adjust("all") 36 | 'arrow' 37 | >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major") 38 | 'arrow<2.0,>=1.2.0' 39 | >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major") 40 | 'arrow<=1.2.2,>=1.2.0 # strict' 41 | >>> _RequirementWithComment("arrow>=1.2.0").adjust("major") 42 | 'arrow>=1.2.0' 43 | >>> _RequirementWithComment("arrow").adjust("major") 44 | 'arrow' 45 | 46 | """ 47 | out = str(self) 48 | if self.strict: 49 | return f"{out} {self.strict_string}" 50 | if unfreeze == "major": 51 | for operator, version in self.specs: 52 | if operator in ("<", "<="): 53 | major = LooseVersion(version).version[0] 54 | # replace upper bound with major version increased by one 55 | return out.replace(f"{operator}{version}", f"<{int(major) + 1}.0") 56 | elif unfreeze == 
"all": 57 | for operator, version in self.specs: 58 | if operator in ("<", "<="): 59 | # drop upper bound 60 | return out.replace(f"{operator}{version},", "") 61 | elif unfreeze != "none": 62 | raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.") 63 | return out 64 | 65 | 66 | def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]: 67 | r"""Adapted from `pkg_resources.parse_requirements` to include comments. 68 | 69 | >>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt'] 70 | >>> [r.adjust('none') for r in _parse_requirements(txt)] 71 | ['this', 'example', 'foo # strict', 'thing'] 72 | >>> txt = '\\n'.join(txt) 73 | >>> [r.adjust('none') for r in _parse_requirements(txt)] 74 | ['this', 'example', 'foo # strict', 'thing'] 75 | 76 | """ 77 | lines = yield_lines(strs) 78 | pip_argument = None 79 | for line in lines: 80 | # Drop comments -- a hash without a space may be in a URL. 81 | if " #" in line: 82 | comment_pos = line.find(" #") 83 | line, comment = line[:comment_pos], line[comment_pos:] 84 | else: 85 | comment = "" 86 | # If there is a line continuation, drop it, and append the next line. 87 | if line.endswith("\\"): 88 | line = line[:-2].strip() 89 | try: 90 | line += next(lines) 91 | except StopIteration: 92 | return 93 | # If there's a pip argument, save it 94 | if line.startswith("--"): 95 | pip_argument = line 96 | continue 97 | if line.startswith("-r "): 98 | # linked requirement files are unsupported 99 | continue 100 | if "@" in line or re.search("https?://", line): 101 | # skip lines with links like `pesq @ git+https://github.com/ludlows/python-pesq` 102 | continue 103 | yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument) 104 | pip_argument = None 105 | 106 | 107 | def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]: 108 | """Load requirements from a file. 
109 | 110 | >>> import os 111 | >>> from lightning_utilities import _PROJECT_ROOT 112 | >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") 113 | >>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE 114 | ['sphinx<6.0,>=4.0', ...] 115 | 116 | """ 117 | if unfreeze not in {"none", "major", "all"}: 118 | raise ValueError(f'unsupported option of "{unfreeze}"') 119 | path = Path(path_dir) / file_name 120 | if not path.exists(): 121 | raise FileNotFoundError(f"missing file for {(path_dir, file_name, path)}") 122 | text = path.read_text() 123 | return [req.adjust(unfreeze) for req in _parse_requirements(text)] 124 | -------------------------------------------------------------------------------- /src/lightning_utilities/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/utilities/2658ff8b341fc0f30e9d9e9042a9d6c0e4de49a2/src/lightning_utilities/py.typed -------------------------------------------------------------------------------- /src/lightning_utilities/test/__init__.py: -------------------------------------------------------------------------------- 1 | """General testing functionality.""" 2 | -------------------------------------------------------------------------------- /src/lightning_utilities/test/warning.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | # 4 | import re 5 | import warnings 6 | from collections.abc import Generator 7 | from contextlib import contextmanager 8 | from typing import Optional 9 | 10 | 11 | @contextmanager 12 | def no_warning_call(expected_warning: type[Warning] = Warning, match: Optional[str] = None) -> Generator: 13 | """Check that no warning was raised/emitted under this context manager.""" 14 | with 
warnings.catch_warnings(record=True) as record: 15 | yield 16 | 17 | for w in record: 18 | if issubclass(w.category, expected_warning) and (match is None or re.compile(match).search(str(w.message))): 19 | raise AssertionError(f"`{expected_warning.__name__}` was raised: {w.message!r}") 20 | -------------------------------------------------------------------------------- /tests/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _PATH_HERE = os.path.dirname(__file__) 4 | _PATH_ROOT = os.path.dirname(os.path.dirname(_PATH_HERE)) 5 | _PATH_SCRIPTS = os.path.join(_PATH_ROOT, "scripts") 6 | -------------------------------------------------------------------------------- /tests/scripts/test_adjust_torch_versions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import subprocess 4 | import sys 5 | 6 | from scripts import _PATH_SCRIPTS 7 | 8 | REQUIREMENTS_SAMPLE = """ 9 | # This is sample requirements file 10 | # with multi line comments 11 | 12 | torchvision >=0.13.0, <0.16.0 # sample # comment 13 | gym[classic,control] >=0.17.0, <0.27.0 14 | ipython[all] <8.15.0 # strict 15 | torchmetrics >=0.10.0, <1.3.0 16 | deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict 17 | 18 | """ 19 | REQUIREMENTS_EXPECTED = """ 20 | # This is sample requirements file 21 | # with multi line comments 22 | 23 | torchvision==0.11.1 # sample # comment 24 | gym[classic,control] >=0.17.0, <0.27.0 25 | ipython[all] <8.15.0 # strict 26 | torchmetrics >=0.10.0, <1.3.0 27 | deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict 28 | 29 | """ 30 | 31 | 32 | def test_adjust_torch_versions_call(tmp_path) -> None: 33 | path_script = os.path.join(_PATH_SCRIPTS, "adjust-torch-versions.py") 34 | path_req_file = str(tmp_path / "requirements.txt") 35 | with open(path_req_file, "w", encoding="utf8") as fopen: 36 | 
fopen.write(REQUIREMENTS_SAMPLE) 37 | 38 | return_code = subprocess.call([sys.executable, path_script, path_req_file, "1.10.0"]) # noqa: S603 39 | assert return_code == 0 40 | 41 | with open(path_req_file, encoding="utf8") as fopen: 42 | req_result = fopen.read() 43 | # ToDO: no idea why parsing lines on windows leave extra line after each line 44 | # tried strip, regex, hard-coded replace but none worked... so adjusting tests 45 | if platform.system() == "Windows": 46 | req_result = req_result.replace("\n\n", "\n") 47 | assert req_result == REQUIREMENTS_EXPECTED 48 | -------------------------------------------------------------------------------- /tests/unittests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _PATH_UNITTESTS = os.path.dirname(__file__) 4 | _PATH_ROOT = os.path.dirname(os.path.dirname(_PATH_UNITTESTS)) 5 | -------------------------------------------------------------------------------- /tests/unittests/cli/test_dependencies.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from lightning_utilities.cli.dependencies import ( 4 | prune_packages_in_requirements, 5 | replace_oldest_version, 6 | replace_package_in_requirements, 7 | ) 8 | 9 | _PATH_ROOT = Path(__file__).parent.parent.parent 10 | 11 | 12 | def test_prune_packages(tmpdir): 13 | req_file = tmpdir / "requirements.txt" 14 | with open(req_file, "w") as fp: 15 | fp.writelines(["fire\n", "abc>=0.1\n"]) 16 | prune_packages_in_requirements("abc", req_files=[str(req_file)]) 17 | with open(req_file) as fp: 18 | lines = fp.readlines() 19 | assert lines == ["fire\n"] 20 | 21 | 22 | def test_oldest_packages(tmpdir): 23 | req_file = tmpdir / "requirements.txt" 24 | with open(req_file, "w") as fp: 25 | fp.writelines(["fire>0.2\n", "abc>=0.1\n"]) 26 | replace_oldest_version(req_files=[str(req_file)]) 27 | with open(req_file) as fp: 28 | lines = fp.readlines() 29 | 
assert lines == ["fire>0.2\n", "abc==0.1\n"] 30 | 31 | 32 | def test_replace_packages(tmpdir): 33 | req_file = tmpdir / "requirements.txt" 34 | with open(req_file, "w") as fp: 35 | fp.writelines(["torchvision>=0.2\n", "torch>=1.0 # comment\n", "torchtext <0.3\n"]) 36 | replace_package_in_requirements(old_package="torch", new_package="pytorch", req_files=[str(req_file)]) 37 | with open(req_file) as fp: 38 | lines = fp.readlines() 39 | assert lines == ["torchvision>=0.2\n", "pytorch>=1.0 # comment\n", "torchtext <0.3\n"] 40 | -------------------------------------------------------------------------------- /tests/unittests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | 3 | import os 4 | import shutil 5 | import tempfile 6 | from pathlib import Path 7 | 8 | import pytest 9 | 10 | from unittests import _PATH_ROOT 11 | 12 | _PATH_DOCS = os.path.join(_PATH_ROOT, "docs", "source") 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def temp_docs(): 17 | """Create a dummy documentation folder.""" 18 | # create a folder for docs 19 | docs_folder = Path(tempfile.mkdtemp()) 20 | # copy all real docs from _PATH_DOCS to local temp_docs 21 | for root, _, files in os.walk(_PATH_DOCS): 22 | for file in files: 23 | fpath = os.path.join(root, file) 24 | temp_path = docs_folder / os.path.relpath(fpath, _PATH_DOCS) 25 | temp_path.parent.mkdir(exist_ok=True, parents=True) 26 | with open(fpath, "rb") as fopen: 27 | temp_path.write_bytes(fopen.read()) 28 | yield str(docs_folder) 29 | # remove the folder 30 | shutil.rmtree(docs_folder.parent, ignore_errors=True) 31 | -------------------------------------------------------------------------------- /tests/unittests/core/test_enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from lightning_utilities.core.enums import StrEnum 4 | 5 | 6 | def test_consistency(): 7 | class 
MyEnum(StrEnum): 8 | FOO = "FOO" 9 | BAR = "BAR" 10 | BAZ = "BAZ" 11 | NUM = "32" 12 | 13 | # normal equality, case invariant 14 | assert MyEnum.FOO == "FOO" 15 | assert MyEnum.FOO == "foo" 16 | 17 | # int support 18 | assert MyEnum.NUM == 32 19 | assert MyEnum.NUM in (32, "32") 20 | 21 | # key-based 22 | assert MyEnum.from_str("num") == MyEnum.NUM 23 | 24 | # collections 25 | assert MyEnum.BAZ not in ("FOO", "BAR") 26 | assert MyEnum.BAZ in ("FOO", "BAZ") 27 | assert MyEnum.BAZ in ("baz", "FOO") 28 | assert MyEnum.BAZ not in {"BAR", "FOO"} 29 | # hash cannot be case invariant 30 | assert MyEnum.BAZ not in {"BAZ", "FOO"} 31 | assert MyEnum.BAZ in {"baz", "FOO"} 32 | 33 | 34 | def test_comparison_with_other_enum(): 35 | class MyEnum(StrEnum): 36 | FOO = "FOO" 37 | 38 | class OtherEnum(Enum): 39 | FOO = 123 40 | 41 | assert not MyEnum.FOO.__eq__(OtherEnum.FOO) 42 | 43 | 44 | def test_create_from_string(): 45 | class MyEnum(StrEnum): 46 | t1 = "T/1" 47 | T2 = "t:2" 48 | 49 | assert MyEnum.from_str("T1", source="key") 50 | assert MyEnum.try_from_str("T1", source="value") is None 51 | assert MyEnum.from_str("T1", source="any") 52 | 53 | assert MyEnum.try_from_str("T:2", source="key") is None 54 | assert MyEnum.from_str("T:2", source="value") 55 | assert MyEnum.from_str("T:2", source="any") 56 | -------------------------------------------------------------------------------- /tests/unittests/core/test_imports.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import re 3 | from unittest import mock 4 | from unittest.mock import Mock 5 | 6 | import pytest 7 | 8 | from lightning_utilities.core.imports import ( 9 | RequirementCache, 10 | compare_version, 11 | get_dependency_min_version_spec, 12 | lazy_import, 13 | module_available, 14 | requires, 15 | ) 16 | 17 | try: 18 | from importlib.metadata import PackageNotFoundError 19 | except ImportError: 20 | # Python < 3.8 21 | from importlib_metadata import 
PackageNotFoundError 22 | 23 | 24 | def test_module_exists(): 25 | assert module_available("_pytest") 26 | assert module_available("_pytest.mark.structures") 27 | assert not module_available("_pytest.mark.asdf") 28 | assert not module_available("asdf") 29 | assert not module_available("asdf.bla.asdf") 30 | 31 | 32 | def testcompare_version(monkeypatch): 33 | monkeypatch.setattr(pytest, "__version__", "1.8.9") 34 | assert not compare_version("pytest", operator.ge, "1.10.0") 35 | assert compare_version("pytest", operator.lt, "1.10.0") 36 | 37 | monkeypatch.setattr(pytest, "__version__", "1.10.0.dev123") 38 | assert compare_version("pytest", operator.ge, "1.10.0.dev123") 39 | assert not compare_version("pytest", operator.ge, "1.10.0.dev124") 40 | 41 | assert compare_version("pytest", operator.ge, "1.10.0.dev123", use_base_version=True) 42 | assert compare_version("pytest", operator.ge, "1.10.0.dev124", use_base_version=True) 43 | 44 | monkeypatch.setattr(pytest, "__version__", "1.10.0a0+0aef44c") # dev version before rc 45 | assert compare_version("pytest", operator.ge, "1.10.0.rc0", use_base_version=True) 46 | assert not compare_version("pytest", operator.ge, "1.10.0.rc0") 47 | assert compare_version("pytest", operator.ge, "1.10.0", use_base_version=True) 48 | assert not compare_version("pytest", operator.ge, "1.10.0") 49 | 50 | 51 | def test_requirement_cache(): 52 | assert RequirementCache(f"pytest>={pytest.__version__}") 53 | assert not RequirementCache(f"pytest<{pytest.__version__}") 54 | assert "pip install -U 'not-found-requirement'" in str(RequirementCache("not-found-requirement")) 55 | 56 | # invalid requirement is skipped by valid module 57 | assert RequirementCache(f"pytest<{pytest.__version__}", "pytest") 58 | 59 | cache = RequirementCache("this_module_is_not_installed") 60 | assert not cache 61 | assert "pip install -U 'this_module_is_not_installed" in str(cache) 62 | 63 | cache = RequirementCache("this_module_is_not_installed", "this_also_is_not") 64 | 
assert not cache 65 | assert "pip install -U 'this_module_is_not_installed" in str(cache) 66 | 67 | cache = RequirementCache("pytest[not-valid-extra]") 68 | assert not cache 69 | assert "pip install -U 'pytest[not-valid-extra]" in str(cache) 70 | 71 | 72 | @mock.patch("lightning_utilities.core.imports.Requirement") 73 | @mock.patch("lightning_utilities.core.imports._version") 74 | @mock.patch("lightning_utilities.core.imports.distribution") 75 | def test_requirement_cache_with_extras(distribution_mock, version_mock, requirement_mock): 76 | requirement_mock().specifier.contains.return_value = True 77 | requirement_mock().name = "jsonargparse" 78 | requirement_mock().extras = [] 79 | version_mock.return_value = "1.0.0" 80 | assert RequirementCache("jsonargparse>=1.0.0") 81 | 82 | with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock: 83 | get_extra_req_mock.return_value = [ 84 | # Extra packages, all versions satisfied 85 | Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))), 86 | Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=True))), 87 | ] 88 | distribution_mock.return_value = Mock(version="0.10.0") 89 | requirement_mock().extras = ["signatures"] 90 | assert RequirementCache("jsonargparse[signatures]>=1.0.0") 91 | 92 | with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock: 93 | get_extra_req_mock.return_value = [ 94 | # Extra packages, but not all versions are satisfied 95 | Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))), 96 | Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=False))), 97 | ] 98 | distribution_mock.return_value = Mock(version="0.10.0") 99 | requirement_mock().extras = ["signatures"] 100 | assert not RequirementCache("jsonargparse[signatures]>=1.0.0") 101 | 102 | 103 | @mock.patch("lightning_utilities.core.imports._version") 104 | 
def test_requirement_cache_with_prerelease_package(version_mock): 105 | version_mock.return_value = "0.11.0" 106 | assert RequirementCache("transformer-engine>=0.11.0") 107 | version_mock.return_value = "0.11.0.dev0+931b44f" 108 | assert not RequirementCache("transformer-engine>=0.11.0") 109 | version_mock.return_value = "1.10.0.dev0+931b44f" 110 | assert RequirementCache("transformer-engine>=0.11.0") 111 | 112 | 113 | def test_module_available_cache(): 114 | assert RequirementCache(module="pytest") 115 | assert not RequirementCache(module="this_module_is_not_installed") 116 | assert "pip install -U this_module_is_not_installed" in str(RequirementCache(module="this_module_is_not_installed")) 117 | 118 | 119 | def test_get_dependency_min_version_spec(): 120 | attrs_min_version_spec = get_dependency_min_version_spec("pytest", "attrs") 121 | assert re.match(r"^>=[\d.]+$", attrs_min_version_spec) 122 | 123 | with pytest.raises(ValueError, match="'invalid' not found in package 'pytest'"): 124 | get_dependency_min_version_spec("pytest", "invalid") 125 | 126 | with pytest.raises(PackageNotFoundError, match="invalid"): 127 | get_dependency_min_version_spec("invalid", "invalid") 128 | 129 | 130 | def test_lazy_import(): 131 | def callback_fcn(): 132 | raise ValueError 133 | 134 | math = lazy_import("math", callback=callback_fcn) 135 | with pytest.raises(ValueError, match=""): # noqa: PT011 136 | math.floor(5.1) 137 | 138 | module = lazy_import("asdf") 139 | with pytest.raises(ModuleNotFoundError, match="No module named 'asdf'"): 140 | print(module) 141 | 142 | os = lazy_import("os") 143 | assert os.getcwd() 144 | 145 | 146 | @requires("torch.unknown.subpackage") 147 | def my_torch_func(i: int) -> int: 148 | import torch # noqa 149 | 150 | return i 151 | 152 | 153 | def test_torch_func_raised(): 154 | with pytest.raises( 155 | ModuleNotFoundError, 156 | match="Required dependencies not available: \nModule not found: 'torch.unknown.subpackage'. 
", 157 | ): 158 | my_torch_func(42) 159 | 160 | 161 | @requires("random") 162 | def my_random_func(nb: int) -> int: 163 | from random import randint 164 | 165 | return randint(0, nb) 166 | 167 | 168 | def test_rand_func_passed(): 169 | assert 0 <= my_random_func(42) <= 42 170 | 171 | 172 | class MyTorchClass: 173 | @requires("torch>99.0", "random") 174 | def __init__(self): 175 | from random import randint 176 | 177 | import torch # noqa 178 | 179 | self._rnd = randint(1, 9) 180 | 181 | 182 | def test_torch_class_raised(): 183 | with pytest.raises( 184 | ModuleNotFoundError, match="Required dependencies not available: \nModule not found: 'torch>99.0'." 185 | ): 186 | MyTorchClass() 187 | 188 | 189 | class MyRandClass: 190 | @requires("random") 191 | def __init__(self, nb: int): 192 | from random import randint 193 | 194 | self._rnd = randint(1, nb) 195 | 196 | 197 | def test_rand_class_passed(): 198 | cls = MyRandClass(42) 199 | assert 0 <= cls._rnd <= 42 200 | -------------------------------------------------------------------------------- /tests/unittests/core/test_inheritance.py: -------------------------------------------------------------------------------- 1 | from lightning_utilities.core.inheritance import get_all_subclasses 2 | 3 | 4 | def test_get_all_subclasses(): 5 | class A1: ... 6 | 7 | class A2(A1): ... 8 | 9 | class B1: ... 10 | 11 | class B2(B1): ... 12 | 13 | class C(A2, B2): ... 
14 | 15 | assert get_all_subclasses(A1) == {A2, C} 16 | assert get_all_subclasses(A2) == {C} 17 | assert get_all_subclasses(B1) == {B2, C} 18 | assert get_all_subclasses(B2) == {C} 19 | assert get_all_subclasses(C) == set() 20 | -------------------------------------------------------------------------------- /tests/unittests/core/test_overrides.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from functools import partial, wraps 3 | from typing import Any, Callable 4 | from unittest.mock import Mock 5 | 6 | import pytest 7 | 8 | from lightning_utilities.core.overrides import is_overridden 9 | 10 | 11 | class LightningModule: 12 | def training_step(self): ... 13 | 14 | 15 | class BoringModel(LightningModule): 16 | def training_step(self): ... 17 | 18 | 19 | class Strategy: 20 | @contextmanager 21 | def model_sharded_context(): ... 22 | 23 | 24 | class SingleDeviceStrategy(Strategy): ... 25 | 26 | 27 | def test_is_overridden(): 28 | assert not is_overridden("whatever", object(), parent=LightningModule) 29 | 30 | class TestModel(BoringModel): 31 | def foo(self): 32 | pass 33 | 34 | def bar(self): 35 | return 1 36 | 37 | with pytest.raises(ValueError, match="The parent should define the method"): 38 | is_overridden("foo", TestModel(), parent=BoringModel) 39 | 40 | # normal usage 41 | assert is_overridden("training_step", BoringModel(), parent=LightningModule) 42 | 43 | # reversed. 
works even without inheritance 44 | assert is_overridden("training_step", LightningModule(), parent=BoringModel) 45 | 46 | class WrappedModel(TestModel): 47 | def __new__(cls, *args: Any, **kwargs: Any): 48 | obj = super().__new__(cls) 49 | obj.foo = cls.wrap(obj.foo) 50 | obj.bar = cls.wrap(obj.bar) 51 | return obj 52 | 53 | @staticmethod 54 | def wrap(fn) -> Callable: 55 | @wraps(fn) 56 | def wrapper(): 57 | fn() 58 | 59 | return wrapper 60 | 61 | def bar(self): 62 | return 2 63 | 64 | # `functools.wraps()` support 65 | assert not is_overridden("foo", WrappedModel(), parent=TestModel) 66 | assert is_overridden("bar", WrappedModel(), parent=TestModel) 67 | 68 | # `Mock` support 69 | mock = Mock(spec=BoringModel, wraps=BoringModel()) 70 | assert is_overridden("training_step", mock, parent=LightningModule) 71 | 72 | # `partial` support 73 | model = BoringModel() 74 | model.training_step = partial(model.training_step) 75 | assert is_overridden("training_step", model, parent=LightningModule) 76 | 77 | # `@contextmanager` support 78 | assert not is_overridden("model_sharded_context", SingleDeviceStrategy(), Strategy) 79 | -------------------------------------------------------------------------------- /tests/unittests/core/test_rank_zero.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 4 | 5 | 6 | def test_rank_zero_only_raises(): 7 | foo = rank_zero_only(lambda x: x + 1) 8 | with pytest.raises(RuntimeError, match="rank_zero_only.rank` needs to be set "): 9 | foo(1) 10 | 11 | 12 | @pytest.mark.parametrize("rank", [0, 1, 4]) 13 | def test_rank_prefixed_message(rank): 14 | rank_zero_only.rank = rank 15 | message = rank_prefixed_message("bar", rank) 16 | assert message == f"[rank: {rank}] bar" 17 | # reset 18 | del rank_zero_only.rank 19 | 20 | 21 | def test_rank_zero_only_default(): 22 | foo = lambda: "foo" 23 | rank_zero_foo 
= rank_zero_only(foo, "not foo") 24 | 25 | rank_zero_only.rank = 0 26 | assert rank_zero_foo() == "foo" 27 | 28 | rank_zero_only.rank = 1 29 | assert rank_zero_foo() == "not foo" 30 | -------------------------------------------------------------------------------- /tests/unittests/docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/utilities/2658ff8b341fc0f30e9d9e9042a9d6c0e4de49a2/tests/unittests/docs/__init__.py -------------------------------------------------------------------------------- /tests/unittests/docs/test_formatting.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import re 3 | 4 | import pytest 5 | 6 | from lightning_utilities.docs import adjust_linked_external_docs 7 | 8 | 9 | @pytest.mark.online 10 | def test_adjust_linked_external_docs(temp_docs): 11 | # take config as it includes API references with `stable` 12 | path_conf = os.path.join(temp_docs, "conf.py") 13 | path_testpage = os.path.join(temp_docs, "test-page.rst") 14 | 15 | def _get_line_with_numpy(path_rst: str, pattern: str) -> str: 16 | with open(path_rst, encoding="UTF-8") as fopen: 17 | lines = fopen.readlines() 18 | # find the first line with figure reference 19 | return next(ln for ln in lines if pattern in ln) 20 | 21 | # validate the initial expectations 22 | line = _get_line_with_numpy(path_conf, pattern='"numpy":') 23 | assert "https://numpy.org/doc/stable/" in line 24 | line = _get_line_with_numpy(path_testpage, pattern="Link to scikit-learn stable documentation:") 25 | assert "https://scikit-learn.org/stable/" in line 26 | 27 | adjust_linked_external_docs( 28 | "https://numpy.org/doc/stable/", "https://numpy.org/doc/{numpy.__version__}/", temp_docs 29 | ) 30 | adjust_linked_external_docs( 31 | "https://scikit-learn.org/stable/", "https://scikit-learn.org/{scikit-learn}/", temp_docs 32 | ) 33 | 34 | # validate the final state 
of index page 35 | line = _get_line_with_numpy(path_conf, pattern='"numpy":') 36 | assert re.search(r"https://numpy.org/doc/([1-9]\d*)\.(\d+)/", line) 37 | line = _get_line_with_numpy(path_testpage, pattern="Link to scikit-learn stable documentation:") 38 | assert re.search(r"https://scikit-learn.org/([1-9]\d*)\.(\d+)/", line) 39 | -------------------------------------------------------------------------------- /tests/unittests/docs/test_retriever.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import shutil 3 | 4 | import pytest 5 | 6 | from lightning_utilities.docs import fetch_external_assets 7 | 8 | 9 | @pytest.mark.online 10 | def test_retriever_s3(temp_docs): 11 | # take the index page 12 | path_index = os.path.join(temp_docs, "index.rst") 13 | # copy it to another location to test depth 14 | path_page = os.path.join(temp_docs, "any", "extra", "page.rst") 15 | os.makedirs(os.path.dirname(path_page), exist_ok=True) 16 | shutil.copy(path_index, path_page) 17 | 18 | def _get_line_with_figure(path_rst: str) -> str: 19 | with open(path_rst, encoding="UTF-8") as fopen: 20 | lines = fopen.readlines() 21 | # find the first line with figure reference 22 | return next(ln for ln in lines if ln.startswith(".. figure::")) 23 | 24 | # validate the initial expectations 25 | line = _get_line_with_figure(path_index) 26 | # that the image exists 27 | assert "Lightning.gif" in line 28 | # and it is sourced in S3 29 | assert ".s3." in line 30 | 31 | fetch_external_assets(docs_folder=temp_docs) 32 | 33 | # validate the final state of index page 34 | line = _get_line_with_figure(path_index) 35 | # that the image exists 36 | assert os.path.join("fetched-s3-assets", "Lightning.gif") in line 37 | # but it is not sourced from S3 38 | assert ".s3." 
not in line 39 | 40 | # validate the final state of additional page 41 | line = _get_line_with_figure(path_page) 42 | # that the image exists in the proper depth 43 | assert os.path.join("..", "..", "fetched-s3-assets", "Lightning.gif") in line 44 | # but it is not sourced from S3 45 | assert ".s3." not in line 46 | -------------------------------------------------------------------------------- /tests/unittests/mocks.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | from typing import Any 3 | 4 | from lightning_utilities.core.imports import package_available 5 | 6 | if package_available("torch"): 7 | import torch 8 | else: 9 | # minimal torch implementation to avoid installing torch in testing CI 10 | class TensorMock: 11 | def __init__(self, data) -> None: 12 | self.data = data 13 | 14 | def __add__(self, other): 15 | """Perform and operation.""" 16 | if isinstance(self.data, Iterable): 17 | if isinstance(other, (int, float)): 18 | return TensorMock([a + other for a in self.data]) 19 | if isinstance(other, Iterable): 20 | return TensorMock([a + b for a, b in zip(self, other)]) 21 | return self.data + other 22 | 23 | def __mul__(self, other): 24 | """Perform mul operation.""" 25 | if isinstance(self.data, Iterable): 26 | if isinstance(other, (int, float)): 27 | return TensorMock([a * other for a in self.data]) 28 | if isinstance(other, Iterable): 29 | return TensorMock([a * b for a, b in zip(self, other)]) 30 | return self.data * other 31 | 32 | def __iter__(self): 33 | """Iterate.""" 34 | return iter(self.data) 35 | 36 | def __repr__(self) -> str: 37 | """Return object representation.""" 38 | return repr(self.data) 39 | 40 | def __eq__(self, other): 41 | """Perform equal operation.""" 42 | return self.data == other 43 | 44 | def add_(self, value): 45 | self.data += value 46 | return self.data 47 | 48 | class TorchMock: 49 | Tensor = TensorMock 50 | 51 | @staticmethod 52 | def 
tensor(data: Any) -> TensorMock: 53 | return TensorMock(data) 54 | 55 | @staticmethod 56 | def equal(a: Any, b: Any) -> bool: 57 | return a == b 58 | 59 | @staticmethod 60 | def arange(*args: Any) -> TensorMock: 61 | return TensorMock(list(range(*args))) 62 | 63 | torch = TorchMock() 64 | -------------------------------------------------------------------------------- /tests/unittests/test/test_warnings.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from re import escape 3 | 4 | import pytest 5 | 6 | from lightning_utilities.test.warning import no_warning_call 7 | 8 | 9 | def test_no_warning_call(): 10 | with no_warning_call(): 11 | ... 12 | 13 | with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")), no_warning_call(): 14 | warnings.warn("foo") 15 | 16 | with no_warning_call(DeprecationWarning): 17 | warnings.warn("foo") 18 | 19 | class MyDeprecationWarning(DeprecationWarning): ... 20 | 21 | with ( 22 | pytest.raises(AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")), 23 | no_warning_call(DeprecationWarning), 24 | ): 25 | warnings.warn("bar", category=MyDeprecationWarning) 26 | --------------------------------------------------------------------------------