├── .codecov.yml
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── documentation.md
│   │   └── feature_request.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── actions
│   │   ├── cache
│   │   │   └── action.yml
│   │   ├── pip-list
│   │   │   └── action.yml
│   │   ├── pkg-create
│   │   │   └── action.yml
│   │   ├── pkg-install
│   │   │   └── action.yml
│   │   └── unittesting
│   │       └── action.yml
│   ├── dependabot.yml
│   ├── labeler-config.yml
│   ├── markdown.links.config.json
│   ├── mergify.yml
│   ├── scripts
│   │   ├── find-unused-caches.py
│   │   └── find-unused-caches.txt
│   ├── stale.yml
│   └── workflows
│       ├── check-docs.yml
│       ├── check-md-links.yml
│       ├── check-package.yml
│       ├── check-precommit.yml
│       ├── check-schema.yml
│       ├── check-typing.yml
│       ├── ci-cli.yml
│       ├── ci-rtfd.yml
│       ├── ci-scripts.yml
│       ├── ci-testing.yml
│       ├── ci-use-checks.yaml
│       ├── cleanup-caches.yml
│       ├── cron-clear-cache.yml
│       ├── deploy-docs.yml
│       ├── label-pr.yml
│       └── release-pypi.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── docs
│   ├── .build_docs.sh
│   ├── .readthedocs.yaml
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── _static
│       │   ├── copybutton.js
│       │   └── images
│       │       ├── icon.svg
│       │       ├── logo-large.svg
│       │       ├── logo-small.svg
│       │       ├── logo.png
│       │       └── logo.svg
│       ├── _templates
│       │   └── theme_variables.jinja
│       ├── conf.py
│       ├── index.rst
│       └── test-page.rst
├── pyproject.toml
├── requirements
│   ├── _docs.txt
│   ├── _tests.txt
│   ├── cli.txt
│   ├── core.txt
│   ├── docs.txt
│   ├── gha-package.txt
│   ├── gha-schema.txt
│   └── typing.txt
├── scripts
│   ├── adjust-torch-versions.py
│   ├── inject-selector-script.py
│   └── run_standalone_tests.sh
├── setup.py
├── src
│   └── lightning_utilities
│       ├── __about__.py
│       ├── __init__.py
│       ├── cli
│       │   ├── __init__.py
│       │   ├── __main__.py
│       │   └── dependencies.py
│       ├── core
│       │   ├── __init__.py
│       │   ├── apply_func.py
│       │   ├── enums.py
│       │   ├── imports.py
│       │   ├── inheritance.py
│       │   ├── overrides.py
│       │   └── rank_zero.py
│       ├── docs
│       │   ├── __init__.py
│       │   ├── formatting.py
│       │   └── retriever.py
│       ├── install
│       │   ├── __init__.py
│       │   └── requirements.py
│       ├── py.typed
│       └── test
│           ├── __init__.py
│           └── warning.py
└── tests
    ├── scripts
    │   ├── __init__.py
    │   └── test_adjust_torch_versions.py
    └── unittests
        ├── __init__.py
        ├── cli
        │   └── test_dependencies.py
        ├── conftest.py
        ├── core
        │   ├── test_apply_func.py
        │   ├── test_enums.py
        │   ├── test_imports.py
        │   ├── test_inheritance.py
        │   ├── test_overrides.py
        │   └── test_rank_zero.py
        ├── docs
        │   ├── __init__.py
        │   ├── test_formatting.py
        │   └── test_retriever.py
        ├── mocks.py
        └── test
            └── test_warnings.py
/.codecov.yml:
--------------------------------------------------------------------------------
1 | # see https://docs.codecov.io/docs/codecov-yaml
2 | # Validation check:
3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate
4 |
5 | # https://docs.codecov.io/docs/codecovyml-reference
6 | codecov:
7 |   bot: "codecov-io"
8 |   strict_yaml_branch: "yaml-config"
9 |   require_ci_to_pass: yes
10 |   notify:
11 |     # after_n_builds: 2
12 |     wait_for_ci: yes
13 |
14 | coverage:
15 |   precision: 0 # 2 = xx.xx%, 0 = xx%
16 |   round: nearest # how coverage is rounded: down/up/nearest
17 |   range: 40...100 # custom range of coverage colors from red -> yellow -> green
18 |   status:
19 |     # https://codecov.readme.io/v1.0/docs/commit-status
20 |     project:
21 |       default:
22 |         informational: true
23 |         target: 99% # specify the target coverage for each commit status
24 |         threshold: 30% # allow this much decrease on project
25 |         # https://github.com/codecov/support/wiki/Filtering-Branches
26 |         # branches: main
27 |         if_ci_failed: error
28 |     # https://github.com/codecov/support/wiki/Patch-Status
29 |     patch:
30 |       default:
31 |         informational: true
32 |         target: 50% # specify the target "X%" coverage to hit
33 |         # threshold: 50% # allow this much decrease on patch
34 |     changes: false
35 |
36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations
37 | github_checks:
38 |   annotations: false
39 |
40 | parsers:
41 |   gcov:
42 |     branch_detection:
43 |       conditional: true
44 |       loop: true
45 |       macro: false
46 |       method: false
47 |   javascript:
48 |     enable_partials: false
49 |
50 | comment:
51 |   layout: header, diff
52 |   require_changes: false
53 |   behavior: default # update if exists else create new
54 |   # branches: *
55 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # This is a comment.
2 | # Each line is a file pattern followed by one or more owners.
3 |
4 | # These owners will be the default owners for everything in
5 | # the repo. Unless a later match takes precedence,
6 | # @global-owner1 and @global-owner2 will be requested for
7 | # review when someone opens a pull request.
8 | * @borda @ethanwharris @justusschock @tchaton
9 |
10 | # CI/CD and configs
11 | /.github/ @borda @ethanwharris @justusschock @tchaton
12 | *.yml @borda @ethanwharris @justusschock @tchaton
13 |
14 | # Docs
15 | /docs/ @borda
16 | /.github/*.md @borda
17 | /.github/ISSUE_TEMPLATE/ @borda
18 |
19 | /.github/CODEOWNERS @borda
20 | /setup.py @borda
21 |
22 | /src @borda @ethanwharris @justusschock @tchaton
23 | /tests/unittests @borda @ethanwharris @justusschock @tchaton
24 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🐛 Bug
10 |
11 |
12 |
13 | ### To Reproduce
14 |
15 | Steps to reproduce the behavior...
16 |
17 |
18 |
19 |
20 | Code sample
21 |
22 |
24 |
25 |
26 |
27 | ### Expected behavior
28 |
29 |
30 |
31 | ### Additional context
32 |
33 |
34 | Environment details
35 |
36 | - PyTorch Version (e.g., 1.0):
37 | - OS (e.g., Linux):
38 | - How you installed PyTorch (`conda`, `pip`, source):
39 | - Build command you used (if compiling from source):
40 | - Python version:
41 | - CUDA/cuDNN version:
42 | - GPU models and configuration:
43 | - Any other relevant information:
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Typos and doc fixes
3 | about: Typos and doc fixes
4 | title: ''
5 | labels: documentation
6 | assignees: ''
7 | ---
8 |
9 | ## 📚 Documentation
10 |
11 | For typos or docs fixes, please go ahead and submit a PR (no need to open an issue).
12 | If you are not sure about the proper solution, please describe the issue here...
13 |
14 | Thanks!
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Feature
10 |
11 |
12 |
13 | ### Motivation
14 |
15 |
16 |
17 | ### Alternatives
18 |
19 |
20 |
21 | ### Additional context
22 |
23 |
24 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 | Before submitting
3 |
4 | - [ ] Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section?
6 | - [ ] Did you make sure to update the docs?
7 | - [ ] Did all existing and newly added tests pass locally?
8 |
9 |
10 |
11 | ## What does this PR do?
12 |
13 | Fixes # (issue).
14 |
15 | ## PR review
16 |
17 | Anyone in the community is free to review the PR once the tests have passed.
18 | If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
19 |
20 |
25 |
--------------------------------------------------------------------------------
/.github/actions/cache/action.yml:
--------------------------------------------------------------------------------
1 | name: Complex caching
2 | description: some more complex caching - pip & conda
3 |
4 | inputs:
5 |   python-version:
6 |     description: Python version
7 |     required: true
8 |   requires:
9 |     description: define oldest, latest, or an empty string
10 |     required: false
11 |     default: ""
12 |   offset:
13 |     description: some extra hash for pip cache
14 |     required: false
15 |     default: ""
16 |   interval:
17 |     description: cache hash reset interval in days
18 |     required: false
19 |     default: "7"
20 |
21 | runs:
22 |   using: "composite"
23 |   steps:
24 |     - name: Determine caches
25 |       id: cache_dirs
26 |       run: echo "pip_dir=$(pip cache dir)" >> $GITHUB_OUTPUT
27 |       shell: bash
28 |
29 |     - name: Cache 💽 pip
30 |       uses: actions/cache@v3
31 |       with:
32 |         path: ${{ steps.cache_dirs.outputs.pip_dir }}
33 |         key: py${{ inputs.python-version }}-pip-${{ inputs.offset }}-${{ hashFiles('requirements.txt') }}
34 |         restore-keys: py${{ inputs.python-version }}-pip-${{ inputs.offset }}-
35 |         enableCrossOsArchive: true
36 |
37 |     - name: Cache 💽 conda
38 |       uses: actions/cache@v3
39 |       if: runner.os == 'Linux'
40 |       with:
41 |         path: ~/conda_pkgs_dir
42 |         key: py${{ inputs.python-version }}-conda-${{ inputs.offset }}
43 |         restore-keys: py${{ inputs.python-version }}-conda-${{ inputs.offset }}
44 |
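45 | # Example usage from another repository (illustrative sketch; the `@main` ref
46 | # and the offset value are placeholders, adjust them to your setup):
47 | #
48 | #   - name: Complex 💽 caching
49 | #     uses: Lightning-AI/utilities/.github/actions/cache@main
50 | #     with:
51 | #       python-version: ${{ matrix.python-version }}
52 | #       offset: "torch"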
--------------------------------------------------------------------------------
/.github/actions/pip-list/action.yml:
--------------------------------------------------------------------------------
1 | name: Print pip list
2 | description: Print pip list for quick access to environment information
3 |
4 | inputs:
5 |   unfold:
6 |     description: Whether to unfold the output of pip list
7 |     required: false
8 |     default: "false"
9 |
10 | runs:
11 |   using: "composite"
12 |   steps:
13 |     - name: pip list
14 |       run: |
15 |         if [ "${{ inputs.unfold }}" = "true" ]; then
16 |           echo '<details open>' >> $GITHUB_STEP_SUMMARY
17 |         else
18 |           echo '<details>' >> $GITHUB_STEP_SUMMARY
19 |         fi
20 |         echo '<summary>pip list</summary>' >> $GITHUB_STEP_SUMMARY
21 |         echo '' >> $GITHUB_STEP_SUMMARY
22 |         echo '```' >> $GITHUB_STEP_SUMMARY
23 |         pip list >> $GITHUB_STEP_SUMMARY
24 |         echo '```' >> $GITHUB_STEP_SUMMARY
25 |         echo '' >> $GITHUB_STEP_SUMMARY
26 |         echo '</details>' >> $GITHUB_STEP_SUMMARY
27 |         pip list # also print to stdout
28 |       shell: bash
29 |
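30 | # Example usage from another repository (illustrative sketch; the `@main` ref is a placeholder):
31 | #
32 | #   - name: Print 🖨️ dependencies
33 | #     uses: Lightning-AI/utilities/.github/actions/pip-list@main
34 | #     with:
35 | #       unfold: "true"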
--------------------------------------------------------------------------------
/.github/actions/pkg-create/action.yml:
--------------------------------------------------------------------------------
1 | name: Create and check package
2 | description: building, checking the package
3 |
4 | runs:
5 |   using: "composite"
6 |   steps:
7 |     - name: Create package 📦
8 |       # python setup.py clean
9 |       run: python -m build --verbose
10 |       shell: bash
11 |
12 |     - name: Check package 📦
13 |       working-directory: dist
14 |       run: |
15 |         ls -lh .
16 |         twine check * --strict
17 |       shell: bash
18 |
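19 | # Example usage (illustrative sketch; the `@main` ref is a placeholder).
20 | # The `build` and `twine` tools must already be installed, e.g. from requirements/gha-package.txt:
21 | #
22 | #   - run: pip install -q -r requirements/gha-package.txt
23 | #   - name: Create package 📦
24 | #     uses: Lightning-AI/utilities/.github/actions/pkg-create@main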
--------------------------------------------------------------------------------
/.github/actions/pkg-install/action.yml:
--------------------------------------------------------------------------------
1 | name: Install and check package
2 | description: installing and validating the package
3 |
4 | inputs:
5 |   install-from:
6 |     description: "Define if the package is from archive or wheel"
7 |     required: true
8 |   pkg-folder:
9 |     description: "Folder containing the downloaded package artifacts"
10 |     required: false
11 |     default: "pypi"
12 |   pkg-extras:
13 |     description: "optional extras to install, including the square brackets, e.g. '[cli]'"
14 |     required: false
15 |     default: ""
16 |   pip-flags:
17 |     description: "additional pip install flags"
18 |     required: false
19 |     default: "-f https://download.pytorch.org/whl/cpu/torch_stable.html"
20 |   import-name:
21 |     description: "Import name to test with after installation"
22 |     required: true
23 |   custom-import-code:
24 |     description: "additional import statement; needs to be full Python code"
25 |     required: false
26 |     default: ""
27 |
28 | runs:
29 |   using: "composite"
30 |   steps:
31 |     - name: show packages
32 |       working-directory: ${{ inputs.pkg-folder }}
33 |       run: |
34 |         ls -lh
35 |         pip -V
36 |         echo "PKG_WHEEL=$(ls *.whl | head -n1)" >> $GITHUB_ENV
37 |         echo "PKG_SOURCE=$(ls *.tar.gz | head -n1)" >> $GITHUB_ENV
38 |         pip list
39 |       shell: bash
40 |
41 |     - name: Install package (archive)
42 |       if: ${{ inputs.install-from == 'archive' }}
43 |       working-directory: ${{ inputs.pkg-folder }}
44 |       run: |
45 |         set -ex
46 |         pip install '${{ env.PKG_SOURCE }}${{ inputs.pkg-extras }}' \
47 |           --force-reinstall ${{ inputs.pip-flags }}
48 |         pip list
49 |       shell: bash
50 |
51 |     - name: Install package (wheel)
52 |       if: ${{ inputs.install-from == 'wheel' }}
53 |       working-directory: ${{ inputs.pkg-folder }}
54 |       run: |
55 |         set -ex
56 |         pip install '${{ env.PKG_WHEEL }}${{ inputs.pkg-extras }}' \
57 |           --force-reinstall ${{ inputs.pip-flags }}
58 |         pip list
59 |       shell: bash
60 |
61 |     - name: package check / import
62 |       if: ${{ inputs.import-name != '' }}
63 |       run: |
64 |         python -c "import ${{ inputs.import-name }} as pkg; print(f'version: {pkg.__version__}')"
65 |       shell: bash
66 |
67 |     - name: package check / custom import
68 |       if: ${{ inputs.custom-import-code != '' }}
69 |       run: |
70 |         python -c '${{ inputs.custom-import-code }}'
71 |       shell: bash
72 |
73 |     - name: Uninstall all
74 |       # TODO: reset env / consider add as conda
75 |       run: |
76 |         pip freeze > _reqs.txt
77 |         pip uninstall -y -r _reqs.txt
78 |       shell: bash
79 |
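80 | # Example usage (illustrative sketch; the `@main` ref and values are placeholders;
81 | # the package artifacts are expected in the `pkg-folder` directory, by default `pypi/`):
82 | #
83 | #   - name: Installing package 📦 as Wheel
84 | #     uses: Lightning-AI/utilities/.github/actions/pkg-install@main
85 | #     with:
86 | #       install-from: "wheel"
87 | #       import-name: "lightning_utilities"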
--------------------------------------------------------------------------------
/.github/actions/unittesting/action.yml:
--------------------------------------------------------------------------------
1 | name: Unittest and coverage
2 | description: pull data samples -> unittests
3 |
4 | inputs:
5 |   python-version:
6 |     description: Python version
7 |     required: true
8 |   pkg-name:
9 |     description: package name for coverage collections
10 |     required: true
11 |   requires:
12 |     description: define oldest or latest
13 |     required: false
14 |     default: ""
15 |   dirs:
16 |     description: Testing folders per domains, space separated string
17 |     required: false
18 |     default: "."
19 |   pytest-args:
20 |     description: Additional pytest arguments such as `--timeout=120`
21 |     required: false
22 |     default: ""
23 |   shell-type:
24 |     description: Define Shell type
25 |     required: false
26 |     default: "bash"
27 |
28 | runs:
29 |   using: "composite"
30 |   steps:
31 |     - name: Python 🐍 details
32 |       run: |
33 |         python --version
34 |         pip --version
35 |         pip list
36 |       shell: ${{ inputs.shell-type }}
37 |
38 |     - name: Determine artifact file name
39 |       run: echo "artifact=test-results-${{ runner.os }}-py${{ inputs.python-version }}-${{ inputs.requires }}" >> $GITHUB_OUTPUT
40 |       id: location
41 |       shell: bash
42 |
43 |     - name: Unittests
44 |       working-directory: ./tests
45 |       run: |
46 |         python -m pytest ${{ inputs.dirs }} \
47 |           --cov=${{ inputs.pkg-name }} --durations=50 ${{ inputs.pytest-args }} \
48 |           --junitxml="${{ steps.location.outputs.artifact }}.xml"
49 |       shell: ${{ inputs.shell-type }}
50 |
51 |     - name: Upload pytest results
52 |       uses: actions/upload-artifact@v4
53 |       with:
54 |         name: ${{ steps.location.outputs.artifact }}
55 |         path: "tests/${{ steps.location.outputs.artifact }}.xml"
56 |         include-hidden-files: true
57 |       if: failure()
58 |
59 |     - name: Statistics
60 |       if: success()
61 |       working-directory: ./tests
62 |       run: |
63 |         coverage xml
64 |         coverage report
65 |       shell: ${{ inputs.shell-type }}
66 |
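67 | # Example usage (sketch mirroring how ci-testing.yml calls this action in-repo):
68 | #
69 | #   - name: Unittest and coverage
70 | #     uses: ./.github/actions/unittesting
71 | #     with:
72 | #       python-version: ${{ matrix.python-version }}
73 | #       pkg-name: "lightning_utilities"
74 | #       dirs: "unittests"
75 | #       pytest-args: "--timeout=120"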
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
2 | version: 2
3 | updates:
4 |   # Enable version updates for python
5 |   - package-ecosystem: "pip"
6 |     directory: "/requirements"
7 |     schedule:
8 |       interval: "weekly"
9 |     labels: ["requirements", "enhancement"]
10 |     pull-request-branch-name:
11 |       separator: "-"
12 |     open-pull-requests-limit: 5
13 |
14 |   # Enable version updates for GitHub Actions
15 |   - package-ecosystem: "github-actions"
16 |     directory: "/"
17 |     schedule:
18 |       interval: "weekly"
19 |     labels: ["ci/cd", "enhancement"]
20 |     pull-request-branch-name:
21 |       separator: "-"
22 |     open-pull-requests-limit: 5
23 |
--------------------------------------------------------------------------------
/.github/labeler-config.yml:
--------------------------------------------------------------------------------
1 | documentation:
2 |   - changed-files:
3 |       - any-glob-to-any-file:
4 |           - docs/**/*
5 |           - README.md
6 |           - requirements/_docs.txt
7 |
8 | ci/cd:
9 |   - changed-files:
10 |       - any-glob-to-any-file:
11 |           - .github/actions/**/*
12 |           - .github/scripts/*
13 |           - .github/workflows/*
14 |           - .pre-commit-config.yaml
15 |
16 | package:
17 |   - changed-files:
18 |       - any-glob-to-any-file:
19 |           - src/**/*
20 |           - MANIFEST.in
21 |           - pyproject.toml
22 |           - setup.py
23 |
24 | tests:
25 |   - changed-files:
26 |       - any-glob-to-any-file:
27 |           - tests/**/*
28 |           - requirements/_tests.txt
29 |
30 | dependencies:
31 |   - changed-files:
32 |       - any-glob-to-any-file:
33 |           - requirements/*
34 |
35 | release:
36 |   - base-branch: "release/*"
37 |
--------------------------------------------------------------------------------
/.github/markdown.links.config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "ignorePatterns": [
3 |     {
4 |       "pattern": "^https://github.com/Lightning-AI/lightning/pull/.*"
5 |     }
6 |   ],
7 |   "httpHeaders": [
8 |     {
9 |       "urls": [
10 |         "https://github.com/",
11 |         "https://guides.github.com/",
12 |         "https://help.github.com/",
13 |         "https://docs.github.com/"
14 |       ],
15 |       "headers": {
16 |         "Accept-Encoding": "zstd, br, gzip, deflate"
17 |       }
18 |     }
19 |   ]
20 | }
21 |
--------------------------------------------------------------------------------
/.github/mergify.yml:
--------------------------------------------------------------------------------
1 | pull_request_rules:
2 |   - name: warn on conflicts
3 |     conditions:
4 |       - conflict
5 |       - -draft # filter-out GH draft PRs
6 |       - -label="has conflicts"
7 |     actions:
8 |       # comment:
9 |       #   message: This pull request is now in conflict... :(
10 |       label:
11 |         add: ["has conflicts"]
12 |
13 |   - name: resolved conflicts
14 |     conditions:
15 |       - -conflict
16 |       - label="has conflicts"
17 |       - -draft # filter-out GH draft PRs
18 |       - -merged # not merged yet
19 |       - -closed
20 |     actions:
21 |       label:
22 |         remove: ["has conflicts"]
23 |
24 |   - name: add core reviewer
25 |     conditions:
26 |       # number of review approvals
27 |       - "#approved-reviews-by<2"
28 |     actions:
29 |       request_reviews:
30 |         users:
31 |           - Borda
32 |
--------------------------------------------------------------------------------
/.github/scripts/find-unused-caches.py:
--------------------------------------------------------------------------------
1 | """Script for filtering unused caches."""
2 |
3 | import os
4 | from datetime import timedelta
5 |
6 |
7 | def fetch_all_caches(repository: str, token: str, per_page: int = 100, max_pages: int = 100) -> list[dict]:
8 | """Fetch list of al caches from a given repository.
9 |
10 | Args:
11 | repository: user / repo-name
12 | token: authentication token for GH API calls
13 | per_page: number of items per listing page
14 | max_pages: max number of listing pages
15 |
16 | """
17 | import requests
18 | from pandas import Timestamp, to_datetime
19 |
20 | # Initialize variables for pagination
21 | all_caches = []
22 |
23 | for page in range(max_pages):
24 | # Get a page of caches for the repository
25 | url = f"https://api.github.com/repos/{repository}/actions/caches?page={page + 1}&per_page={per_page}"
26 | headers = {"Authorization": f"token {token}"}
27 | response = requests.get(url, headers=headers, timeout=10).json()
28 | if "total_count" not in response:
29 | raise RuntimeError(response.get("message"))
30 | print(f"fetching page... {page} with {per_page} items of expected {response.get('total_count')}")
31 | caches = response.get("actions_caches", [])
32 |
33 | # Append the caches from this page to the overall list
34 | all_caches.extend(caches)
35 |
36 | # Check if there are more pages to retrieve
37 | if len(caches) < per_page:
38 | break
39 |
40 | # Iterate through all caches and list them
41 | if all_caches:
42 | current_date = Timestamp.now(tz="UTC")
43 | print(f"Caches {len(all_caches)} for {repository}:")
44 | for cache in all_caches:
45 | cache_key = cache["id"]
46 | created_at = to_datetime(cache["created_at"])
47 | last_used_at = to_datetime(cache["last_accessed_at"])
48 | cache["last_used_days"] = current_date - last_used_at
49 | age_used = cache["last_used_days"].round(freq="min")
50 | size = cache["size_in_bytes"] / (1024 * 1024)
51 | print(
52 | f"- Cache Key: {cache_key} |"
53 | f" Created At: {created_at.strftime('%Y-%m-%d %H:%M')} |"
54 | f" Used At: {last_used_at.strftime('%Y-%m-%d %H:%M')} [{age_used}] |"
55 | f" Size: {size:.2f} MB"
56 | )
57 | else:
58 | print("No caches found for the repository.")
59 | return all_caches
60 |
61 |
62 | def main(repository: str, token: str, age_days: float = 7, output_file: str = "unused-cashes.txt") -> None:
63 | """Entry point for CLI.
64 |
65 | Args:
66 | repository: GitHub repository name in form `/`
67 | token: authentication token for making API calls
68 | age_days: filter all caches older than this age set in days
69 | output_file: path to a file for dumping list of cache's Id
70 |
71 | """
72 | caches = fetch_all_caches(repository, token)
73 |
74 | delta_days = timedelta(days=age_days)
75 | old_caches = [str(cache["id"]) for cache in caches if cache["last_used_days"] > delta_days]
76 | print(f"found {len(old_caches)} old caches:\n {old_caches}")
77 |
78 | with open(output_file, "w", encoding="utf8") as fw:
79 | fw.write(os.linesep.join(old_caches))
80 |
81 |
82 | if __name__ == "__main__":
83 | from jsonargparse import auto_cli, set_parsing_settings
84 |
85 | set_parsing_settings(parse_optionals_as_positionals=True)
86 | auto_cli(main)
87 |
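88 | # Example invocation (mirrors the call in cleanup-caches.yml; the token needs access
89 | # to the repository's Actions caches):
90 | #   python find-unused-caches.py --repository="Lightning-AI/utilities" --token=$GH_TOKEN --age_days=7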
--------------------------------------------------------------------------------
/.github/scripts/find-unused-caches.txt:
--------------------------------------------------------------------------------
1 | # Requirements for running the equally named script;
2 | # kept in a separate file to prevent version discrepancies if hardcoded in the workflow
3 |
4 | jsonargparse[signatures] >=4.38.0
5 | requests
6 | pandas
7 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # https://github.com/marketplace/stale
2 |
3 | # Number of days of inactivity before an issue becomes stale
4 | daysUntilStale: 60
5 | # Number of days of inactivity before a stale issue is closed
6 | daysUntilClose: 14
7 | # Issues with these labels will never be considered stale
8 | exemptLabels:
9 |   - pinned
10 |   - security
11 | # Label to use when marking an issue as stale
12 | staleLabel: won't fix
13 | # Comment to post when marking an issue as stale. Set to `false` to disable
14 | markComment: >
15 |   This issue has been automatically marked as stale because it has not had
16 |   recent activity. It will be closed if no further activity occurs. Thank you
17 |   for your contributions.
18 | # Comment to post when closing a stale issue. Set to `false` to disable
19 | closeComment: false
20 |
21 | # Set to true to ignore issues in a project (defaults to false)
22 | exemptProjects: true
23 | # Set to true to ignore issues in a milestone (defaults to false)
24 | exemptMilestones: true
25 | # Set to true to ignore issues with an assignee (defaults to false)
26 | exemptAssignees: true
27 |
--------------------------------------------------------------------------------
/.github/workflows/check-docs.yml:
--------------------------------------------------------------------------------
1 | name: Building docs
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       actions-ref:
7 |         description: "Version of actions, normally the same as workflow"
8 |         required: false
9 |         type: string
10 |         default: ""
11 |       python-version:
12 |         description: "Python version to use"
13 |         default: "3.9"
14 |         required: false
15 |         type: string
16 |       docs-dir:
17 |         description: "Working directory to run make docs in"
18 |         default: "./docs"
19 |         required: false
20 |         type: string
21 |       timeout-minutes:
22 |         description: "timeout-minutes for each job"
23 |         default: 15
24 |         required: false
25 |         type: number
26 |       requirements-file:
27 |         description: "path to the requirement file"
28 |         default: "requirements/docs.txt"
29 |         required: false
30 |         type: string
31 |       env-vars:
32 |         description: "custom environment variables in json format"
33 |         required: false
34 |         type: string
35 |         default: |
36 |           {
37 |             "SPHINX_MOCK_REQUIREMENTS": 0
38 |           }
39 |       make-target:
40 |         description: "what test configs to run in json format"
41 |         required: false
42 |         type: string
43 |         default: |
44 |           ["html", "doctest", "linkcheck"]
45 |       install-tex:
46 |         description: "optional installing Texlive support - true|false"
47 |         required: false
48 |         type: string
49 |         default: "false"
50 |
51 | defaults:
52 |   run:
53 |     shell: bash
54 |
55 | env:
56 |   # just use CPU version since running on CPU machine
57 |   TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
58 |   # default 0 means to keep for the maximum time
59 |   KEEP_DAYS: 0
60 |
61 | jobs:
62 |   make-docs:
63 |     runs-on: ubuntu-22.04
64 |     timeout-minutes: ${{ inputs.timeout-minutes }}
65 |     env: ${{ fromJSON(inputs.env-vars) }}
66 |     strategy:
67 |       fail-fast: false
68 |       matrix:
69 |         target: ${{ fromJSON(inputs.make-target) }}
70 |     steps:
71 |       - name: Checkout 🛎
72 |         uses: actions/checkout@v4
73 |         with:
74 |           submodules: recursive
75 |
76 |       - name: Set up Python 🐍 ${{ inputs.python-version }}
77 |         uses: actions/setup-python@v5
78 |         with:
79 |           python-version: ${{ inputs.python-version }}
80 |           cache: "pip"
81 |
82 |       - name: Install pandoc & texlive
83 |         if: ${{ inputs.install-tex == 'true' }}
84 |         timeout-minutes: 10
85 |         run: |
86 |           sudo apt-get update --fix-missing
87 |           sudo apt-get install -y \
88 |             pandoc \
89 |             texlive-latex-extra \
90 |             dvipng \
91 |             texlive-pictures \
92 |             latexmk
93 |       - name: Install dependencies
94 |         timeout-minutes: 20
95 |         run: |
96 |           pip --version
97 |           pip install -e . -U -r ${{ inputs.requirements-file }} -f ${TORCH_URL}
98 |           pip list
99 |
100 |       - name: Pull reusable 🤖 actions️
101 |         if: ${{ inputs.actions-ref != '' }}
102 |         uses: actions/checkout@v4
103 |         with:
104 |           ref: ${{ inputs.actions-ref }}
105 |           path: .cicd
106 |           repository: Lightning-AI/utilities
107 |       - name: Print 🖨️ dependencies
108 |         if: ${{ inputs.actions-ref != '' }}
109 |         uses: ./.cicd/.github/actions/pip-list
110 |
111 |       - name: Build documentation
112 |         working-directory: ${{ inputs.docs-dir }}
113 |         run: |
114 |           make ${{ matrix.target }} \
115 |             --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"
116 |
117 |       - name: Shorten keep artifact
118 |         if: startsWith(github.event_name, 'pull_request')
119 |         run: echo "KEEP_DAYS=7" >> $GITHUB_ENV
120 |       - name: Upload built docs
121 |         uses: actions/upload-artifact@v4
122 |         with:
123 |           name: docs-${{ matrix.target }}-${{ github.sha }}
124 |           path: ${{ inputs.docs-dir }}/build/
125 |           retention-days: ${{ env.KEEP_DAYS }}
126 |           include-hidden-files: true
127 |
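128 | # Example caller workflow (illustrative sketch; the `@main` ref and values are placeholders):
129 | #
130 | #   check-docs:
131 | #     uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main
132 | #     with:
133 | #       requirements-file: "requirements/_docs.txt"
134 | #       make-target: '["html", "linkcheck"]'
135 | #       install-tex: "true"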
--------------------------------------------------------------------------------
/.github/workflows/check-md-links.yml:
--------------------------------------------------------------------------------
1 | name: Check Markdown links
2 | # https://github.com/gaurav-nelson/github-action-markdown-link-check
3 |
4 | on:
5 |   workflow_call:
6 |     inputs:
7 |       base-branch:
8 |         description: "Default branch name"
9 |         required: true
10 |         type: string
11 |         default: "master"
12 |       config-file:
13 |         description: "Config file (JSON) used for markdown link checks"
14 |         required: false
15 |         type: string
16 |         default: ""
17 |
18 | jobs:
19 |   markdown-link-check:
20 |     runs-on: ubuntu-24.04
21 |     env:
22 |       CONFIG_FILE: ${{ inputs.config-file }}
23 |       MODIFIED_ONLY: "no"
24 |     steps:
25 |       - uses: actions/checkout@master
26 |
27 |       - name: Create local version of config
28 |         if: ${{ inputs.config-file == '' }}
29 |         run: |
30 |           echo '{
31 |             "ignorePatterns": [
32 |               {
33 |                 "pattern": "^https://github.com/${{ github.repository }}/pull/.*"
34 |               }
35 |             ],
36 |             "httpHeaders": [
37 |               {
38 |                 "urls": ["https://github.com/", "https://guides.github.com/", "https://help.github.com/", "https://docs.github.com/"],
39 |                 "headers": {
40 |                   "Accept-Encoding": "zstd, br, gzip, deflate"
41 |                 }
42 |               }
43 |             ]
44 |           }' > 'markdown.links.config.json'
45 |           echo "CONFIG_FILE=markdown.links.config.json" >> $GITHUB_ENV
46 |           cat 'markdown.links.config.json'
47 |       - name: Show config
48 |         run: cat ${{ env.CONFIG_FILE }}
49 |
50 |       - name: narrow scope for PR
51 |         if: startsWith(github.event_name, 'pull_request')
52 |         run: echo "MODIFIED_ONLY=yes" >> $GITHUB_ENV
53 |       - name: Checking markdown link
54 |         uses: gaurav-nelson/github-action-markdown-link-check@v1
55 |         with:
56 |           base-branch: ${{ inputs.base-branch }}
57 |           use-quiet-mode: "yes"
58 |           check-modified-files-only: ${{ env.MODIFIED_ONLY }}
59 |           config-file: ${{ env.CONFIG_FILE }}
60 |
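61 | # Example caller workflow (illustrative sketch mirroring ci-use-checks.yaml; `@main` is a placeholder):
62 | #
63 | #   check-md-links:
64 | #     uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@main
65 | #     with:
66 | #       base-branch: main
67 | #       config-file: ".github/markdown.links.config.json"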
--------------------------------------------------------------------------------
/.github/workflows/check-package.yml:
--------------------------------------------------------------------------------
1 | name: Check package flow
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       actions-ref:
7 |         description: "Version of actions, normally the same as workflow"
8 |         required: true
9 |         type: string
10 |       artifact-name:
11 |         description: "Unique name for collecting artifacts; it shall be unique across all workflows"
12 |         required: true
13 |         type: string
14 |       install-extras:
15 |         description: "optional extras to install, including the square brackets, e.g. '[cli]'"
16 |         required: false
17 |         type: string
18 |         default: ""
19 |       install-flags:
20 |         description: "Additional pip install flags"
21 |         required: false
22 |         type: string
23 |         default: "-f https://download.pytorch.org/whl/cpu/torch_stable.html"
24 |       import-name:
25 |         description: "Import name to test with after installation"
26 |         required: true
27 |         type: string
28 |       custom-import-code:
29 |         description: "additional import statement; needs to be full Python code"
30 |         type: string
31 |         required: false
32 |         default: ""
33 |       build-matrix:
34 |         description: "what building configs in json format, expected keys are `os` and `python-version`"
35 |         required: false
36 |         type: string
37 |         default: |
38 |           {
39 |             "os": ["ubuntu-latest"]
40 |           }
41 |       testing-matrix:
42 |         description: "what test configs to run in json format, expected keys are `os` and `python-version`"
43 |         required: false
44 |         type: string
45 |         # default operating systems should be pinned to specific versions instead of "-latest" for stability
46 |         # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners#supported-runners-and-hardware-resources
47 |         default: |
48 |           {
49 |             "os": ["ubuntu-22.04", "macos-13", "windows-2022"],
50 |             "python-version": ["3.9", "3.13"]
51 |           }
52 |       env-vars:
53 |         description: "custom environment variables in json format"
54 |         required: false
55 |         type: string
56 |         default: |
57 |           {
58 |             "SAMPLE_ENV_VARIABLE": 1
59 |           }
60 |
61 | defaults:
62 |   run:
63 |     shell: bash
64 |
65 | jobs:
66 |   pkg-build:
67 |     runs-on: ${{ matrix.os }}
68 |     env: ${{ fromJSON(inputs.env-vars) }}
69 |     strategy:
70 |       fail-fast: false
71 |       matrix: ${{ fromJSON(inputs.build-matrix) }}
72 |     steps:
73 |       - name: Checkout 🛎️
74 |         uses: actions/checkout@v4
75 |         with:
76 |           fetch-depth: 0 # checkout entire history for all branches (required when using scm-based versioning)
77 |           submodules: recursive
78 |       - name: Set up Python 🐍
79 |         uses: actions/setup-python@v5
80 |         with:
81 |           python-version: ${{ matrix.python-version || '3.x' }}
82 |
83 |       - name: Pull reusable 🤖 actions️
84 |         uses: actions/checkout@v4
85 |         with:
86 |           ref: ${{ inputs.actions-ref }}
87 |           path: .cicd
88 |           repository: Lightning-AI/utilities
89 |       - name: Prepare build env.
90 |         run: |
91 |           pip install -q -r ./.cicd/requirements/gha-package.txt
92 |           pip list
93 |       - name: Create package 📦
94 |         uses: ./.cicd/.github/actions/pkg-create
95 |       - name: Upload 📤 packages
96 |         if: ${{ inputs.artifact-name != '' }}
97 |         uses: actions/upload-artifact@v4
98 |         with:
99 |           name: ${{ inputs.artifact-name }}-build-${{ strategy.job-index }}
100 |           path: dist
101 |
102 |   merge-artifacts:
103 |     needs: pkg-build
104 |     runs-on: ubuntu-latest
105 |     steps:
106 |       - name: Pull reusable 🤖 actions️
107 |         uses: actions/checkout@v4
108 |         with:
109 |           ref: ${{ inputs.actions-ref }}
110 |           path: .cicd
111 |           repository: Lightning-AI/utilities
112 |       - name: Prepare build env.
113 |         run: |
114 |           pip install -q -r ./.cicd/requirements/gha-package.txt
115 |           pip list
116 |
117 |       - name: Download 📥
118 |         uses: actions/download-artifact@v4
119 |         with:
120 |           # download all build artifacts
121 |           pattern: ${{ inputs.artifact-name }}-build-*
122 |           merge-multiple: true
123 |           path: dist
124 |       - name: Brief look
125 |         run: |
126 |           ls -lh dist/
127 |           twine check dist/*
128 |       - name: Upload 📤
129 |         uses: actions/upload-artifact@v4
130 |         with:
131 |           name: ${{ inputs.artifact-name }}
132 |           path: dist
133 |
134 |   pkg-check:
135 |     needs: merge-artifacts
136 |     runs-on: ${{ matrix.os }}
137 |     env: ${{ fromJSON(inputs.env-vars) }}
138 |     strategy:
139 |       fail-fast: false
140 |       matrix: ${{ fromJSON(inputs.testing-matrix) }}
141 |     steps:
142 |       - name: Checkout 🛎️
143 |         uses: actions/checkout@v4
144 |         with:
145 |           submodules: recursive
146 |       - name: Set up Python 🐍 ${{ matrix.python-version }}
147 |         uses: actions/setup-python@v5
148 |         with:
149 |           python-version: ${{ matrix.python-version || '3.x' }}
150 |
151 |       - name: Pull reusable 🤖 actions️
152 |         uses: actions/checkout@v4
153 |         with:
154 |           ref: ${{ inputs.actions-ref }}
155 |           path: .cicd
156 |           repository: Lightning-AI/utilities
157 |       - name: Download 📥 all packages
158 |         if: ${{ inputs.artifact-name != '' }}
159 |         uses: actions/download-artifact@v4
160 |         with:
161 |           name: ${{ inputs.artifact-name }}
162 |           path: pypi
163 |       - name: Installing package 📦 as Archive
164 |         timeout-minutes: 10
165 |         uses: ./.cicd/.github/actions/pkg-install
166 |         with:
167 |           install-from: "archive"
168 |           pkg-extras: ${{ inputs.install-extras }}
169 |           pip-flags: ${{ inputs.install-flags }}
170 |           import-name: ${{ inputs.import-name }}
171 |           custom-import-code: ${{ inputs.custom-import-code }}
172 |       - name: Installing package 📦 as Wheel
173 |         timeout-minutes: 10
174 |         uses: ./.cicd/.github/actions/pkg-install
175 |         with:
176 |           install-from: "wheel"
177 |           pkg-extras: ${{ inputs.install-extras }}
178 |           pip-flags: ${{ inputs.install-flags }}
179 |           import-name: ${{ inputs.import-name }}
180 |           custom-import-code: ${{ inputs.custom-import-code }}
181 |
182 |   # TODO: add run doctests
183 |
184 |   pkg-guardian:
185 |     runs-on: ubuntu-latest
186 |     needs: pkg-check
187 |     if: always()
188 |     steps:
189 |       - run: echo "${{ needs.pkg-check.result }}"
190 |       - name: failing...
191 |         if: needs.pkg-check.result == 'failure'
192 |         run: exit 1
193 |       - name: cancelled or skipped...
194 |         if: contains(fromJSON('["cancelled", "skipped"]'), needs.pkg-check.result)
195 |         timeout-minutes: 1
196 |         run: sleep 90
197 |
--------------------------------------------------------------------------------
/.github/workflows/check-precommit.yml:
--------------------------------------------------------------------------------
1 | name: Check formatting flow
2 |
3 | on:
4 |   workflow_call:
5 |     secrets:
6 |       github-token:
7 |         description: "if a user's GH token is provided, it will be used to push updates; requires `push` event"
8 |         required: false
9 |     inputs:
10 |       python-version:
11 |         description: "Python version to use"
12 |         default: "3.9"
13 |         required: false
14 |         type: string
15 |       use-cache:
16 |         description: "enable using GH caching for performance boost"
17 |         type: boolean
18 |         required: false
19 |         default: true
20 |       push-fixes:
21 |         description: "if true, push pre-commit fixes back to the branch; requires `push` event and a token"
22 |         type: boolean
23 |         required: false
24 |         default: false
25 |
26 | defaults:
27 |   run:
28 |     shell: bash
29 |
30 | jobs:
31 |   pre-commit:
32 |     runs-on: ubuntu-22.04
33 |     steps:
34 |       - name: Checkout 🛎️
35 |         uses: actions/checkout@v4
36 |         with:
37 |           fetch-depth: 0
38 |           submodules: recursive
39 |           token: ${{ secrets.github-token || github.token }}
40 |
41 |       - name: Set up Python 🐍
42 |         uses: actions/setup-python@v5
43 |         with:
44 |           python-version: ${{ inputs.python-version }}
45 |
46 |       - name: Cache 💽 pre-commit
47 |         if: ${{ inputs.use-cache == true }}
48 |         uses: actions/cache@v4
49 |         with:
50 |           path: ~/.cache/pre-commit
51 |           key: pre-commit|py${{ inputs.python-version }}|${{ hashFiles('.pre-commit-config.yaml') }}
52 |
53 |       - name: Run pre-commit 🤖
54 |         id: precommit
55 |         run: |
56 |           pip install -q pre-commit
57 |           pre-commit run --all-files
58 |
59 |       - name: Fixing Pull Request ↩️
60 |         if: always() && inputs.push-fixes == true && steps.precommit.outcome == 'failure'
61 |         uses: actions-js/push@v1.5
62 |         with:
63 |           github_token: ${{ secrets.github-token || github.token }}
64 |           message: "pre-commit: running and fixing..."
65 |           branch: ${{ github.ref_name }}
66 |
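67 | # Example caller workflow (illustrative sketch; `@main` and `PAT_TOKEN` are placeholders;
68 | # the secret is only needed when `push-fixes` is enabled):
69 | #
70 | #   check-precommit:
71 | #     uses: Lightning-AI/utilities/.github/workflows/check-precommit.yml@main
72 | #     with:
73 | #       push-fixes: true
74 | #     secrets:
75 | #       github-token: ${{ secrets.PAT_TOKEN }}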
--------------------------------------------------------------------------------
/.github/workflows/check-schema.yml:
--------------------------------------------------------------------------------
1 | name: Check schema flow
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       actions-ref:
7 |         description: "Version of actions, normally the same as workflow"
8 |         required: false
9 |         type: string
10 |         default: ""
11 |       azure-dir:
12 |         description: "Directory containing Azure Pipelines config files. Provide an empty string to skip checking on Azure Pipelines files."
13 |         default: ".azure/"
14 |         required: false
15 |         type: string
16 |       azure-schema-version:
17 |         description: "Version of Azure Pipelines schema to use."
18 |         default: "v1.208.0"
19 |         required: false
20 |         type: string
21 |
22 | defaults:
23 |   run:
24 |     shell: bash
25 |
26 | jobs:
27 |   schema:
28 |     runs-on: ubuntu-24.04
29 |     steps:
30 |       - name: Checkout 🛎
31 |         uses: actions/checkout@v4
32 |         with:
33 |           submodules: recursive
34 |       - name: Set up Python
35 |         uses: actions/setup-python@v5
36 |         with:
37 |           python-version: "3.10"
38 |
39 |       # if actions version is given install defined versions
40 |       - name: "[optional] Pull reusable 🤖 actions"
41 |         if: inputs.actions-ref != ''
42 |         uses: actions/checkout@v4
43 |         with:
44 |           ref: ${{ inputs.actions-ref }}
45 |           path: .cicd
46 |           repository: Lightning-AI/utilities
47 |       - name: "[optional] Install recommended dependencies"
48 |         if: inputs.actions-ref != ''
49 |         timeout-minutes: 5
50 |         run: |
51 |           pip install -r ./.cicd/requirements/gha-schema.txt
52 |           pip list | grep "check-jsonschema"
53 |       # otherwise fall back to using the latest
54 |       - name: "[default] Install recommended dependencies"
55 |         if: inputs.actions-ref == ''
56 |         timeout-minutes: 5
57 |         run: |
58 |           pip install -q check-jsonschema
59 |           pip list | grep "check-jsonschema"
60 |
61 |       - name: Scan repo
62 |         id: folders
63 |         run: python -c "import os; print('gh_actions=' + str(int(os.path.isdir('.github/actions'))))" >> $GITHUB_OUTPUT
64 |
65 |       # https://github.com/SchemaStore/schemastore/blob/master/src/schemas/json/github-workflow.json
66 |       - name: GitHub Actions - workflow
67 |         run: |
68 |           files=$(find .github/workflows -name '*.yml' -or -name '*.yaml' -not -name '__*')
69 |           for f in $files; do
70 |             echo $f;
71 |             check-jsonschema -v $f --builtin-schema "github-workflows";
72 |           done
73 |
74 |       # https://github.com/SchemaStore/schemastore/blob/master/src/schemas/json/github-action.json
75 |       - name: GitHub Actions - action
76 |         if: steps.folders.outputs.gh_actions == '1'
77 |         run: |
78 |           files=$(find .github/actions -name '*.yml' -or -name '*.yaml')
79 |           for f in $files; do
80 |             echo $f;
81 |             check-jsonschema -v $f --builtin-schema "github-actions";
82 |           done
83 |
84 |       # https://github.com/microsoft/azure-pipelines-vscode/blob/main/service-schema.json
85 |       - name: Azure Pipelines
86 |         if: ${{ inputs.azure-dir != '' }}
87 |         env:
88 |           SCHEMA_FILE: https://raw.githubusercontent.com/microsoft/azure-pipelines-vscode/${{ inputs.azure-schema-version }}/service-schema.json
89 |         run: |
90 |           files=$(find ${{ inputs.azure-dir }} -name '*.yml' -or -name '*.yaml')
91 |           for f in $files; do
92 |             echo $f;
93 |             check-jsonschema -v $f --schemafile "$SCHEMA_FILE" --regex-variant="nonunicode";
94 |           done
95 |
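96 | # Example caller workflow (illustrative sketch mirroring ci-use-checks.yaml; `@main` is a placeholder):
97 | #
98 | #   check-schema:
99 | #     uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
100 | #     with:
101 | #       azure-dir: "" # skip Azure Pipelines checks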
--------------------------------------------------------------------------------
/.github/workflows/check-typing.yml:
--------------------------------------------------------------------------------
1 | name: Check typing flow
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       actions-ref:
7 |         description: "Version of actions, normally the same as workflow"
8 |         required: false
9 |         type: string
10 |         default: ""
11 |       python-version:
12 |         description: "Python version to use"
13 |         default: "3.10"
14 |         required: false
15 |         type: string
16 |       source-dir:
17 |         description: "Source directory to check"
18 |         default: "src/"
19 |         required: false
20 |         type: string
21 |       extra-typing:
22 |         description: "Package extra to be installed for type checks + include mypy"
23 |         default: "test"
24 |         required: false
25 |         type: string
26 |
27 | defaults:
28 |   run:
29 |     shell: bash
30 |
31 | jobs:
32 |   mypy:
33 |     runs-on: ubuntu-24.04
34 |     steps:
35 |       - name: Checkout 🛎️
36 |         uses: actions/checkout@v4
37 |         with:
38 |           submodules: recursive
39 |
40 |       - name: Set up Python 🐍 ${{ inputs.python-version }}
41 |         uses: actions/setup-python@v5
42 |         with:
43 |           python-version: ${{ inputs.python-version }}
44 |
45 |       - name: Install dependencies
46 |         timeout-minutes: 20
47 |         run: |
48 |           # don't use --upgrade to respect the version installed via setup.py
49 |           pip install -e '.[${{ inputs.extra-typing }}]' mypy \
50 |             --extra-index-url https://download.pytorch.org/whl/cpu/torch_stable.html
51 |           pip list
52 |
53 |       - name: Pull reusable 🤖 actions️
54 |         if: ${{ inputs.actions-ref != '' }}
55 |         uses: actions/checkout@v4
56 |         with:
57 |           ref: ${{ inputs.actions-ref }}
58 |           path: .cicd
59 |           repository: Lightning-AI/utilities
60 |       - name: Print 🖨️ dependencies
61 |         if: ${{ inputs.actions-ref != '' }}
62 |         uses: ./.cicd/.github/actions/pip-list
63 |         with:
64 |           unfold: true
65 |
66 |       # see: https://github.com/python/mypy/issues/10600#issuecomment-857351152
67 |       - name: init mypy
68 |         continue-on-error: true
69 |         run: |
70 |           mkdir -p .mypy_cache
71 |           mypy --install-types --non-interactive .
72 |
73 |       - name: Check typing
74 |         # mypy uses the config file found in the following order:
75 |         # 1. mypy.ini
76 |         # 2. pyproject.toml
77 |         # 3. setup.cfg
78 |         # 4. $XDG_CONFIG_HOME/mypy/config
79 |         # 5. ~/.config/mypy/config
80 |         # 6. ~/.mypy.ini
81 |         # https://mypy.readthedocs.io/en/stable/config_file.html
82 |         run: mypy
83 |
84 |       - name: suggest ignores
85 |         if: failure()
86 |         env:
87 |           SOURCE_DIR: ${{ inputs.source-dir }}
88 |         run: |
89 |           mypy --no-error-summary 2>&1 \
90 |             | tr ':' ' ' \
91 |             | awk '{print $1}' \
92 |             | sort \
93 |             | uniq \
94 |             | sed "s/\.py//g; s|${SOURCE_DIR}||g; s|\/__init__||g; s|\/|\.|g" \
95 |             | xargs -I {} echo '"{}",' \
96 |             || true
97 |
--------------------------------------------------------------------------------
/.github/workflows/ci-cli.yml:
--------------------------------------------------------------------------------
1 | name: Test CLI
2 |
3 | on:
4 |   push:
5 |     branches: [main, "release/*"]
6 |   pull_request:
7 |     branches: [main, "release/*"]
8 |
9 | defaults:
10 |   run:
11 |     shell: bash
12 |
13 | jobs:
14 |   test-cli:
15 |     runs-on: ${{ matrix.os }}
16 |     strategy:
17 |       fail-fast: false
18 |       matrix:
19 |         os: ["ubuntu-latest", "macos-latest", "windows-latest"]
20 |         python-version: ["3.10"]
21 |     timeout-minutes: 10
22 |     steps:
23 |       - name: Checkout 🛎
24 |         uses: actions/checkout@v4
25 |       - name: Set up Python 🐍 ${{ matrix.python-version }}
26 |         uses: actions/setup-python@v5
27 |         with:
28 |           python-version: ${{ matrix.python-version }}
29 |
30 |       - name: install package
31 |         run: |
32 |           pip install -e '.[cli]'
33 |           pip list
34 |
35 |       - name: run CLI
36 |         working-directory: ./requirements
37 |         run: |
38 |           python -m lightning_utilities.cli version
39 |           python -m lightning_utilities.cli --help
40 |           python -m lightning_utilities.cli requirements set-oldest --req_files="cli.txt"
41 |           python -m lightning_utilities.cli requirements set-oldest --req_files='["cli.txt", "docs.txt"]'
42 |
43 |   cli-guardian:
44 |     runs-on: ubuntu-latest
45 |     needs: test-cli
46 |     if: always()
47 |     steps:
48 |       - run: echo "${{ needs.test-cli.result }}"
49 |       - name: failing...
50 |         if: needs.test-cli.result == 'failure'
51 |         run: exit 1
52 |       - name: cancelled or skipped...
53 |         if: contains(fromJSON('["cancelled", "skipped"]'), needs.test-cli.result)
54 |         timeout-minutes: 1
55 |         run: sleep 90
56 |
--------------------------------------------------------------------------------
/.github/workflows/ci-rtfd.yml:
--------------------------------------------------------------------------------
1 | name: RTFD Preview
2 | on:
3 |   pull_request_target:
4 |     types:
5 |       - opened
6 |
7 | permissions:
8 |   pull-requests: write
9 |
10 | jobs:
11 |   documentation-links:
12 |     runs-on: ubuntu-latest
13 |     steps:
14 |       - uses: readthedocs/actions/preview@v1
15 |         with:
16 |           project-slug: "lit-utilities"
17 |
--------------------------------------------------------------------------------
/.github/workflows/ci-scripts.yml:
--------------------------------------------------------------------------------
1 | name: Test scripts
2 |
3 | on:
4 |   push:
5 |     branches: [main, "release/*"]
6 |   pull_request:
7 |     branches: [main, "release/*"]
8 |
9 | defaults:
10 |   run:
11 |     shell: bash
12 |
13 | jobs:
14 |   test-scripts:
15 |     runs-on: ${{ matrix.os }}
16 |     strategy:
17 |       fail-fast: false
18 |       matrix:
19 |         os: ["ubuntu-latest", "macos-latest", "windows-latest"]
20 |         python-version: ["3.10"]
21 |     timeout-minutes: 15
22 |     steps:
23 |       - name: Checkout 🛎
24 |         uses: actions/checkout@v4
25 |       - name: Set up Python 🐍 ${{ matrix.python-version }}
26 |         uses: actions/setup-python@v5
27 |         with:
28 |           python-version: ${{ matrix.python-version }}
29 |           cache: "pip"
30 |
31 |       - name: Install dependencies
32 |         timeout-minutes: 5
33 |         run: |
34 |           pip install -r requirements/_tests.txt
35 |           pip --version
36 |           pip list
37 |
38 |       - name: test Scripts
39 |         working-directory: ./scripts
40 |         run: pytest . -v
41 |
42 |   standalone-run:
43 |     runs-on: "ubuntu-22.04"
44 |     timeout-minutes: 20
45 |     env:
46 |       TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
47 |     steps:
48 |       - name: Checkout 🛎
49 |         uses: actions/checkout@v4
50 |       - name: Set up Python 🐍 3.10
51 |         uses: actions/setup-python@v5
52 |         with:
53 |           python-version: "3.10"
54 |           cache: "pip"
55 |       - name: Install dependencies
56 |         timeout-minutes: 20
57 |         run: |
58 |           set -e
59 |           pip install -e . -U -r requirements/_tests.txt -f $TORCH_URL
60 |           pip --version
61 |           pip list
62 |
63 |       - name: Run standalone script
64 |         run: bash ./scripts/run_standalone_tests.sh "tests"
65 |         env:
66 |           COVERAGE_SOURCE: "lightning_utilities"
67 |
68 |   scripts-guardian:
69 |     runs-on: ubuntu-latest
70 |     needs: [test-scripts, standalone-run]
71 |     if: always()
72 |     steps:
73 |       - run: echo "${{ needs.test-scripts.result }} | ${{ needs.standalone-run.result }}"
74 |       - name: failing...
75 |         if: needs.test-scripts.result == 'failure' || needs.standalone-run.result == 'failure'
76 |         run: exit 1
77 |       - name: cancelled or skipped...
78 |         if: contains(fromJSON('["cancelled", "skipped"]'), needs.test-scripts.result)
79 |         timeout-minutes: 1
80 |         run: sleep 90
81 |
--------------------------------------------------------------------------------
/.github/workflows/ci-testing.yml:
--------------------------------------------------------------------------------
1 | name: UnitTests
2 |
3 | on:
4 |   push:
5 |     branches: [main, "release/*"]
6 |   pull_request:
7 |     branches: [main, "release/*"]
8 |
9 | defaults:
10 |   run:
11 |     shell: bash
12 |
13 | jobs:
14 |   pytester:
15 |     runs-on: ${{ matrix.os }}
16 |     strategy:
17 |       fail-fast: false
18 |       matrix:
19 |         os: ["ubuntu-22.04", "macos-13", "windows-2022"]
20 |         python-version: ["3.9", "3.11", "3.13"]
21 |         requires: ["oldest", "latest"]
22 |         exclude:
23 |           - { requires: "oldest", python-version: "3.13" }
24 |     timeout-minutes: 35
25 |     env:
26 |       TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
27 |     steps:
28 |       - name: Checkout 🛎
29 |         uses: actions/checkout@v4
30 |         with:
31 |           submodules: recursive
32 |       - name: Set up Python 🐍 ${{ matrix.python-version }}
33 |         uses: actions/setup-python@v5
34 |         with:
35 |           python-version: ${{ matrix.python-version }}
36 |
37 |       - name: Set oldest dependencies
38 |         if: matrix.requires == 'oldest'
39 |         timeout-minutes: 20
40 |         run: |
41 |           pip install -e '.[cli]'
42 |           python -m lightning_utilities.cli requirements set-oldest
43 |
44 |       - name: Complex 💽 caching
45 |         uses: ./.github/actions/cache
46 |         with:
47 |           python-version: ${{ matrix.python-version }}
48 |
49 |       - name: Install dependencies
50 |         timeout-minutes: 20
51 |         run: |
52 |           set -e
53 |           pip install -e . -U -r requirements/_tests.txt -f $TORCH_URL
54 |           pip --version
55 |           pip list
56 |
57 |       - name: Print 🖨️ dependencies
58 |         uses: ./.github/actions/pip-list
59 |
60 |       - name: Unittest and coverage
61 |         uses: ./.github/actions/unittesting
62 |         with:
63 |           python-version: ${{ matrix.python-version }}
64 |           dirs: "unittests"
65 |           pkg-name: "lightning_utilities"
66 |           pytest-args: "--timeout=120"
67 |
68 |       - name: Upload coverage to Codecov
69 |         uses: codecov/codecov-action@v5.4.3
70 |         continue-on-error: true
71 |         with:
72 |           token: ${{ secrets.CODECOV_TOKEN }}
73 |           file: ./coverage.xml
74 |           flags: unittests
75 |           env_vars: OS,PYTHON
76 |           name: codecov-umbrella
77 |           fail_ci_if_error: false
78 |
79 |       - name: test CI scripts
80 |         working-directory: ./tests
81 |         run: python -m pytest scripts --durations=50 --timeout=120
82 |
83 |   testing-guardian:
84 |     runs-on: ubuntu-latest
85 |     needs: pytester
86 |     if: always()
87 |     steps:
88 |       - run: echo "${{ needs.pytester.result }}"
89 |       - name: failing...
90 |         if: needs.pytester.result == 'failure'
91 |         run: exit 1
92 |       - name: cancelled or skipped...
93 |         if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result)
94 |         timeout-minutes: 1
95 |         run: sleep 90
96 |
--------------------------------------------------------------------------------
/.github/workflows/ci-use-checks.yaml:
--------------------------------------------------------------------------------
1 | name: Apply checks
2 |
3 | on:
4 |   push:
5 |     branches: [main, "release/*"]
6 |   pull_request:
7 |     branches: [main, "release/*"]
8 |
9 | concurrency:
10 |   group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 |   cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
12 |
13 | jobs:
14 |   check-typing:
15 |     uses: ./.github/workflows/check-typing.yml
16 |     with:
17 |       actions-ref: ${{ github.sha }} # use local version
18 |       extra-typing: "typing"
19 |
20 |   check-precommit:
21 |     uses: ./.github/workflows/check-precommit.yml
22 |
23 |   check-schema-latest:
24 |     uses: ./.github/workflows/check-schema.yml
25 |     with:
26 |       azure-dir: ""
27 |
28 |   check-schema-fixed:
29 |     uses: ./.github/workflows/check-schema.yml
30 |     with:
31 |       actions-ref: ${{ github.sha }} # use local version
32 |       azure-dir: ""
33 |
34 |   check-schema:
35 |     runs-on: ubuntu-latest
36 |     # just aggregation of the previous two jobs
37 |     needs: ["check-schema-latest", "check-schema-fixed"]
38 |     steps:
39 |       - run: echo "done"
40 |
41 |   check-package:
42 |     uses: ./.github/workflows/check-package.yml
43 |     with:
44 |       actions-ref: ${{ github.sha }} # use local version
45 |       artifact-name: dist-packages-${{ github.sha }}
46 |       import-name: "lightning_utilities"
47 |       build-matrix: |
48 |         {
49 |           "os": ["ubuntu-22.04"],
50 |           "python-version": ["3.10"]
51 |         }
52 |       testing-matrix: |
53 |         {
54 |           "os": ["ubuntu-22.04", "macos-14", "windows-2022"],
55 |           "python-version": ["3.9", "3.13"]
56 |         }
57 |
58 |   check-package-extras:
59 |     uses: ./.github/workflows/check-package.yml
60 |     with:
61 |       actions-ref: ${{ github.sha }} # use local version
62 |       artifact-name: dist-packages-extras-${{ github.sha }}
63 |       import-name: "lightning_utilities"
64 |       install-extras: "[cli]"
65 |       # todo: when we have a module with dependence on extra, replace it
66 |       # tried to import `lightning_utilities.cli.__main__` but it reported that it does not exist
67 |       custom-import-code: "import jsonargparse"
68 |       build-matrix: |
69 |         {
70 |           "os": ["ubuntu-latest", "macos-latest"],
71 |           "python-version": ["3.10"]
72 |         }
73 |       testing-matrix: |
74 |         {
75 |           "os": ["ubuntu-latest", "macos-latest", "windows-latest"],
76 |           "python-version": ["3.10"]
77 |         }
78 |
79 |   check-docs:
80 |     uses: ./.github/workflows/check-docs.yml
81 |     with:
82 |       actions-ref: ${{ github.sha }} # use local version
83 |       requirements-file: "requirements/_docs.txt"
84 |       install-tex: "true"
85 |
86 |   check-md-links-default:
87 |     uses: ./.github/workflows/check-md-links.yml
88 |     with:
89 |       base-branch: main
90 |
91 |   check-md-links-w-config:
92 |     uses: ./.github/workflows/check-md-links.yml
93 |     with:
94 |       base-branch: main
95 |       config-file: ".github/markdown.links.config.json"
96 |
--------------------------------------------------------------------------------
/.github/workflows/cleanup-caches.yml:
--------------------------------------------------------------------------------
1 | name: Cleaning caches
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       scripts-ref:
7 |         description: "Version of script, normally the same as workflow"
8 |         required: true
9 |         type: string
10 |       #gh-token:
11 |       #  description: 'PAT which is authorized to delete caches for given repo'
12 |       #  required: true
13 |       #  type: string
14 |       dry-run:
15 |         description: "allow just listing and not delete any, options yes|no"
16 |         required: true
17 |         type: string
18 |       age-days:
19 |         description: "setting the age of caches in days to be dropped"
20 |         required: true
21 |         type: number
22 |         default: 7
23 |       pattern:
24 |         description: "string to grep cache keys with"
25 |         required: false
26 |         type: string
27 |         default: ""
28 |
29 | defaults:
30 |   run:
31 |     shell: bash
32 |
33 | jobs:
34 |   cleanup-caches:
35 |     runs-on: ubuntu-latest
36 |     timeout-minutes: 15
37 |     env:
38 |       GH_TOKEN: ${{ github.token }}
39 |       AGE_DAYS: ${{ inputs.age-days }}
40 |     steps:
41 |       - name: Checkout Code
42 |         uses: actions/checkout@v4
43 |
44 |       - name: Pull reusable 🤖 actions️
45 |         uses: actions/checkout@v4
46 |         with:
47 |           ref: ${{ inputs.scripts-ref }}
48 |           path: .cicd
49 |           repository: Lightning-AI/utilities
50 |       - name: install requirements
51 |         timeout-minutes: 20
52 |         run: pip install -U -r ./.cicd/.github/scripts/find-unused-caches.txt
53 |
54 |       - name: List and Filter 🔍 caches
55 |         run: |
56 |           python ./.cicd/.github/scripts/find-unused-caches.py \
57 |             --repository="${{ github.repository }}" --token=${GH_TOKEN} --age_days=${AGE_DAYS}
58 |           cat unused-caches.txt
59 |
60 |       - name: Delete 🗑️ caches
61 |         if: inputs.dry-run != 'true'
62 |         run: |
63 |           # Use a while loop to read each line from the file
64 |           while read -r line || [ -n "$line" ]; do
65 |             echo "$line";
66 |             # delete each cache based on file...
67 |             gh api --method DELETE -H "Accept: application/vnd.github+json" /repos/${{ github.repository }}/actions/caches/$line;
68 |           done < "unused-caches.txt"
69 |
--------------------------------------------------------------------------------
/.github/workflows/cron-clear-cache.yml:
--------------------------------------------------------------------------------
1 | name: Clear cache weekly
2 |
3 | on:
4 | schedule:
5 | # on Sunday's night 2am
6 | - cron: "0 2 * * 0"
7 | pull_request:
8 | paths:
9 | - ".github/scripts/find-unused-caches.py"
10 | - ".github/scripts/find-unused-caches.txt"
11 | - ".github/workflows/cleanup-caches.yml"
12 | - ".github/workflows/cron-clear-cache.yml"
13 | workflow_dispatch:
14 | inputs:
15 | pattern:
16 | description: "pattern for cleaning cache"
17 | default: "pip"
18 | required: false
19 | type: string
20 | age-days:
21 | description: "setting the age of caches in days to be dropped"
22 | required: true
23 | type: number
24 | default: 7
25 |
26 | jobs:
27 | drop-unused-caches:
28 | uses: ./.github/workflows/cleanup-caches.yml
29 | with:
30 | scripts-ref: ${{ github.sha }} # use local version
31 | dry-run: ${{ github.event_name == 'pull_request' }}
32 |       # use the input if set, else the default...
33 | pattern: ${{ inputs.pattern || 'pip|conda' }}
34 | age-days: ${{ fromJSON(inputs.age-days) || 2 }}
35 |
--------------------------------------------------------------------------------
/.github/workflows/deploy-docs.yml:
--------------------------------------------------------------------------------
1 | name: "Deploy Docs"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 |
8 | defaults:
9 | run:
10 | shell: bash
11 |
12 | jobs:
13 | # https://github.com/marketplace/actions/deploy-to-github-pages
14 | build-docs-deploy:
15 | runs-on: ubuntu-latest
16 | env:
17 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
18 | steps:
19 | - name: Checkout 🛎️
20 | uses: actions/checkout@v4
21 | # If you're using actions/checkout@v4 you must set persist-credentials to false in most cases for the deployment to work correctly.
22 | with:
23 | persist-credentials: false
24 | submodules: recursive
25 | - name: Set up Python 🐍
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: "3.10"
29 | cache: "pip"
30 |
31 | # Note: This uses an internal pip API and may not always work
32 | # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
33 | - name: Cache 💽 pip
34 | uses: actions/cache@v4
35 | with:
36 | path: ~/.cache/pip
37 | key: pip-${{ hashFiles('requirements/*.txt') }}
38 | restore-keys: pip-
39 |
40 | #- name: Install texlive
41 | # timeout-minutes: 20
42 | # run: |
43 | # # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
44 | # sudo apt-get update --fix-missing
45 | # sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
46 | # shell: bash
47 |
48 | - name: Install dependencies
49 | timeout-minutes: 20
50 | run: |
51 | pip --version
52 | pip install -e . -U -q -r requirements/_docs.txt -f ${TORCH_URL}
53 | pip list
54 | shell: bash
55 |
56 | - name: Make Documentation
57 | working-directory: ./docs
58 | run: make html --jobs 2
59 |
60 | - name: Deploy 🚀
61 | uses: JamesIves/github-pages-deploy-action@v4.7.3
62 | if: github.ref == 'refs/heads/main'
63 | with:
64 | token: ${{ secrets.GITHUB_TOKEN }}
65 | branch: gh-pages # The branch the action should deploy to.
66 | folder: docs/build/html # The folder the action should deploy.
67 | clean: true # Automatically remove deleted files from the deploy branch
68 | target-folder: docs # If you'd like to push the contents of the deployment folder into a specific directory
69 | single-commit: true # you'd prefer to have a single commit on the deployment branch instead of full history
70 |
--------------------------------------------------------------------------------
/.github/workflows/label-pr.yml:
--------------------------------------------------------------------------------
1 | name: Label Pull Requests
2 | on: [pull_request_target]
3 |
4 | jobs:
5 | triage:
6 | permissions:
7 | contents: read
8 | pull-requests: write
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/labeler@v5
12 | with:
13 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
14 | configuration-path: .github/labeler-config.yml
15 |
--------------------------------------------------------------------------------
/.github/workflows/release-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the main branch
5 | push:
6 | branches: [main, "release/*"]
7 | tags: ["v?[0-9]+.[0-9]+.[0-9]+"]
8 | release:
9 | types: [published]
10 | pull_request:
11 | branches: [main]
12 |
13 | defaults:
14 | run:
15 | shell: bash
16 |
17 | jobs:
18 | # based on https://github.com/pypa/gh-action-pypi-publish
19 | build-package:
20 | runs-on: ubuntu-22.04
21 | timeout-minutes: 10
22 | steps:
23 | - name: Checkout 🛎️
24 | uses: actions/checkout@v4
25 | - name: Set up Python 🐍
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: "3.10"
29 | - name: Prepare build env.
30 | run: pip install -r ./requirements/gha-package.txt
31 | - name: Create 📦 package
32 | uses: ./.github/actions/pkg-create
33 | - name: Upload 📤 packages
34 | uses: actions/upload-artifact@v4
35 | with:
36 | name: pypi-packages-${{ github.sha }}
37 | path: dist
38 |
39 | upload-package:
40 | needs: build-package
41 | if: github.event_name == 'release'
42 | timeout-minutes: 5
43 | runs-on: ubuntu-latest
44 | steps:
45 | - name: Checkout 🛎️
46 | uses: actions/checkout@v4
47 | - name: Download 📥 artifact
48 | uses: actions/download-artifact@v4
49 | with:
50 | name: pypi-packages-${{ github.sha }}
51 | path: dist
52 |       - name: Local 🗃️ files
53 | run: ls -lh dist/
54 |
55 | - name: Upload to release
56 | uses: AButler/upload-release-assets@v3.0
57 | with:
58 | files: "dist/*"
59 | repo-token: ${{ secrets.GITHUB_TOKEN }}
60 |
61 | publish-package:
62 | needs: build-package
63 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
64 | runs-on: ubuntu-latest
65 | timeout-minutes: 5
66 | steps:
67 | - name: Checkout 🛎️
68 | uses: actions/checkout@v4
69 | with:
70 | submodules: recursive
71 | - name: Download 📥 artifact
72 | uses: actions/download-artifact@v4
73 | with:
74 | name: pypi-packages-${{ github.sha }}
75 | path: dist
77 |       - name: Local 🗃️ files
77 | run: ls -lh dist/
78 |
79 |       # Publish to Test PyPI first, since failures there aren't critical
80 | - name: Publish to Test PyPI
81 | uses: pypa/gh-action-pypi-publish@v1.12.4
82 | with:
83 | user: __token__
84 | password: ${{ secrets.test_pypi_password }}
85 | repository-url: https://test.pypi.org/legacy/
86 | verbose: true
87 |
88 | - name: Publish distribution 📦 to PyPI
89 | uses: pypa/gh-action-pypi-publish@v1.12.4
90 | with:
91 | user: __token__
92 | password: ${{ secrets.pypi_password }}
93 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Sphinx documentation
58 | docs/_build/
59 | docs/source/fetched-s3-assets
60 | docs/source/api/
61 | docs/source/*.md
62 |
63 | # PyBuilder
64 | target/
65 |
66 | # Jupyter Notebook
67 | .ipynb_checkpoints
68 |
69 | # IPython
70 | profile_default/
71 | ipython_config.py
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
77 | __pypackages__/
78 |
79 | # Celery stuff
80 | celerybeat-schedule
81 | celerybeat.pid
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 | .dmypy.json
108 | dmypy.json
109 |
110 | # Pyre type checker
111 | .pyre/
112 |
113 | # PyCharm
114 | .idea/
115 |
116 | # Lightning logs
117 | lightning_logs
118 | *.gz
119 | .DS_Store
120 | .*_submit.py
121 |
122 | Formatting
123 | .ruff_cache/
124 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
7 | autoupdate_schedule: monthly
8 | # submodules: true
9 |
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v5.0.0
13 | hooks:
14 | - id: end-of-file-fixer
15 | - id: trailing-whitespace
16 | - id: check-case-conflict
17 | - id: check-yaml
18 | - id: check-toml
19 | - id: check-json
20 | - id: check-added-large-files
21 | - id: check-docstring-first
22 | - id: detect-private-key
23 |
24 | - repo: https://github.com/codespell-project/codespell
25 | rev: v2.4.1
26 | hooks:
27 | - id: codespell
28 | additional_dependencies: [tomli]
29 | #args: ["--write-changes"] # uncomment if you want to get automatic fixing
30 |
31 | - repo: https://github.com/PyCQA/docformatter
32 | rev: v1.7.7
33 | hooks:
34 | - id: docformatter
35 | additional_dependencies: [tomli]
36 | args: ["--in-place"]
37 |
38 | - repo: https://github.com/executablebooks/mdformat
39 | rev: 0.7.22
40 | hooks:
41 | - id: mdformat
42 | additional_dependencies:
43 | - mdformat-gfm
44 | - mdformat-black
45 | - mdformat_frontmatter
46 | args: ["--number"]
47 | exclude: CHANGELOG.md
48 |
49 | - repo: https://github.com/JoC0de/pre-commit-prettier
50 | rev: v3.5.3
51 | hooks:
52 | - id: prettier
53 | files: \.(json|yml|yaml|toml)
54 | # https://prettier.io/docs/en/options.html#print-width
55 | args: ["--print-width=120"]
56 |
57 | - repo: https://github.com/astral-sh/ruff-pre-commit
58 | rev: v0.11.12
59 | hooks:
60 | - id: ruff
61 | args: ["--fix"]
62 | - id: ruff-format
63 | - id: ruff
64 |
65 | - repo: https://github.com/sphinx-contrib/sphinx-lint
66 | rev: v1.0.0
67 | hooks:
68 | - id: sphinx-lint
69 |
70 | - repo: https://github.com/tox-dev/pyproject-fmt
71 | rev: v2.6.0
72 | hooks:
73 | - id: pyproject-fmt
74 | additional_dependencies: [tox]
75 | - repo: https://github.com/abravalheri/validate-pyproject
76 | rev: v0.24.1
77 | hooks:
78 | - id: validate-pyproject
79 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2018-2021 William Falcon
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include CHANGELOG.md
2 | recursive-include src py.typed
3 |
4 | graft requirements
5 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test clean docs
2 |
3 | # to imitate SLURM, set only a single node
4 | export SLURM_LOCALID=0
5 | # assume you have installed the needed packages
6 | export SPHINX_MOCK_REQUIREMENTS=0
7 |
8 | test:
9 | pip install -q -r requirements/cli.txt -r requirements/_tests.txt
10 |
11 | # use this to run tests
12 | rm -rf _ckpt_*
13 | rm -rf ./lightning_logs
14 | python -m coverage run --source src/lightning_utilities -m pytest src/lightning_utilities tests -v
15 | python -m coverage report
16 |
17 | # specific file
18 | # python -m coverage run --source src/lightning_utilities -m pytest --flake8 --durations=0 -v -k
19 |
20 | docs: clean
21 | pip install -e . -q -r requirements/_docs.txt
22 | cd docs && $(MAKE) html
23 |
24 | clean:
25 | # clean all temp runs
26 | rm -rf .mypy_cache
27 | rm -rf .pytest_cache
28 | rm -rf .ruff_cache
29 | rm -rf build
30 | rm -rf dist
31 | rm -rf src/*.egg-info
32 | rm -rf ./docs/build
33 | rm -rf ./docs/source/**/generated
34 | rm -rf ./docs/source/api
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lightning Utilities
2 |
3 | [](https://badge.fury.io/py/lightning-utilities)
4 | [](https://github.com/Lightning-AI/utilities/blob/master/LICENSE)
5 | [](https://pepy.tech/project/lightning-utilities)
6 | [](https://pypi.org/project/lightning-utilities/)
7 |
8 | [](https://github.com/Lightning-AI/utilities/actions/workflows/ci-testing.yml)
9 | [](https://github.com/Lightning-AI/utilities/actions/workflows/ci-use-checks.yaml)
10 | [](https://lit-utilities.readthedocs.io/en/latest/?badge=latest)
11 | [](https://results.pre-commit.ci/latest/github/Lightning-AI/utilities/main)
12 |
13 | __This repository covers the following use-cases:__
14 |
15 | 1. _Reusable GitHub workflows_
16 | 2. _Shared GitHub actions_
17 | 3. _General Python utilities in `lightning_utilities.core`_
18 | 4. _CLI `python -m lightning_utilities.cli --help`_
19 |
20 | ## 1. Reusable workflows
21 |
22 | __Usage:__
23 |
24 | ```yaml
25 | name: Check schema
26 |
27 | on: [push]
28 |
29 | jobs:
30 |
31 | check-schema:
32 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.5.0
33 | with:
34 | azure-dir: "" # skip Azure check
35 |
36 | check-code:
37 | uses: Lightning-AI/utilities/.github/workflows/check-code.yml@main
38 | with:
39 |       actions-ref: main # normally you should use the same version as the workflow
40 | ```
41 |
42 | See usage of other workflows in [.github/workflows/ci-use-checks.yaml](https://github.com/Lightning-AI/utilities/tree/main/.github/workflows/ci-use-checks.yaml).
43 |
44 | ## 2. Reusable composite actions
45 |
46 | See available composite actions [.github/actions/](https://github.com/Lightning-AI/utilities/tree/main/.github/actions).
47 |
48 | __Usage:__
49 |
50 | ```yaml
51 | name: Do something with cache
52 |
53 | on: [push]
54 |
55 | jobs:
56 | pytest:
57 | runs-on: ubuntu-24.04
58 | steps:
59 | - uses: actions/checkout@v3
60 | - uses: actions/setup-python@v4
61 | with:
62 | python-version: 3.9
63 | - uses: Lightning-AI/utilities/.github/actions/cache
64 | with:
65 | python-version: 3.9
66 | requires: oldest # or latest
67 | ```
68 |
69 | ## 3. General Python utilities `lightning_utilities.core`
70 |
71 |
72 | __Installation__
73 | From source:
74 |
75 | ```bash
76 | pip install https://github.com/Lightning-AI/utilities/archive/refs/heads/main.zip
77 | ```
78 |
79 | From PyPI:
80 |
81 | ```bash
82 | pip install lightning_utilities
83 | ```
84 |
85 |
86 |
87 | __Usage:__
88 |
89 | Example for optional imports:
90 |
91 | ```python
92 | from lightning_utilities.core.imports import module_available
93 |
94 | if module_available("some_package.something"):
95 | from some_package import something
96 | ```
97 |
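The same `lightning_utilities.core.imports` module also ships a `RequirementCache` helper for version-gated imports. A minimal sketch (the `scikit-learn` requirement below is just illustrative):

```python
from lightning_utilities.core.imports import RequirementCache

# illustrative bound; gate on whichever optional dependency you need
_SKLEARN_AVAILABLE = RequirementCache("scikit-learn>=1.0")

if _SKLEARN_AVAILABLE:
    import sklearn  # noqa: F401
else:
    # str(...) carries a human-readable explanation of what is missing
    raise ModuleNotFoundError(str(_SKLEARN_AVAILABLE))
```
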
98 | ## 4. CLI `lightning_utilities.cli`
99 |
100 | The package provides common CLI commands.
101 |
102 |
103 | __Installation__
104 |
105 | From PyPI:
106 |
107 | ```bash
108 | pip install lightning_utilities[cli]
109 | ```
110 |
111 |
112 |
113 | __Usage:__
114 |
115 | ```bash
116 | python -m lightning_utilities.cli [group] [command]
117 | ```
118 |
119 |
120 | __Example for setting min versions__
121 |
122 | ```console
123 | $ cat requirements/test.txt
124 | coverage>=5.0
125 | codecov>=2.1
126 | pytest>=6.0
127 | pytest-cov
128 | pytest-timeout
129 | $ python -m lightning_utilities.cli requirements set-oldest
130 | $ cat requirements/test.txt
131 | coverage==5.0
132 | codecov==2.1
133 | pytest==6.0
134 | pytest-cov
135 | pytest-timeout
136 | ```
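
Conceptually, `set-oldest` rewrites each `>=` lower bound into an exact `==` pin, as in the listing above. A minimal sketch of that idea (the real logic lives in `src/lightning_utilities/cli/dependencies.py` and also handles comments and other specifiers):

```python
import re


def pin_oldest(requirement: str) -> str:
    """Turn a '>=' lower bound into an exact '==' pin; lines without a bound stay unchanged."""
    return re.sub(r">=", "==", requirement, count=1)


assert pin_oldest("pytest>=6.0") == "pytest==6.0"
assert pin_oldest("pytest-cov") == "pytest-cov"  # no bound -> unchanged
```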
137 |
138 |
139 |
--------------------------------------------------------------------------------
/docs/.build_docs.sh:
--------------------------------------------------------------------------------
1 | make clean
2 | make html --debug --jobs $(nproc)
3 |
--------------------------------------------------------------------------------
/docs/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml in docs/ ~ needs to be re-defined in the RTFD settings UI
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the OS, Python version and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.10"
13 | # You can also specify other tool versions:
14 | # nodejs: "20"
15 |
16 | # Build documentation in the docs/ directory with Sphinx
17 | sphinx:
18 | configuration: docs/source/conf.py
19 | # Fail on all warnings to avoid broken references
20 | fail_on_warning: true
21 |
22 | # Optionally build your docs in additional formats such as PDF and ePub
23 | formats:
24 | - htmlzip
25 | - pdf
26 |
27 | # Optionally set the version of Python and requirements required to build your docs
28 | python:
29 | install:
30 | - requirements: requirements/_docs.txt
31 | - method: pip
32 | path: .
33 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -T -W
6 | SPHINXBUILD = python $(shell which sphinx-build)
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/copybutton.js:
--------------------------------------------------------------------------------
1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */
2 | $(document).ready(function () {
3 | /* Add a [>>>] button on the top-right corner of code samples to hide
4 | * the >>> and ... prompts and the output and thus make the code
5 | * copyable. */
6 | var div = $(
7 | ".highlight-python .highlight," +
8 | ".highlight-python3 .highlight," +
9 | ".highlight-pycon .highlight," +
10 | ".highlight-default .highlight",
11 | );
12 | var pre = div.find("pre");
13 |
14 | // get the styles from the current theme
15 | pre.parent().parent().css("position", "relative");
16 | var hide_text = "Hide the prompts and output";
17 | var show_text = "Show the prompts and output";
18 | var border_width = pre.css("border-top-width");
19 | var border_style = pre.css("border-top-style");
20 | var border_color = pre.css("border-top-color");
21 | var button_styles = {
22 | cursor: "pointer",
23 | position: "absolute",
24 | top: "0",
25 | right: "0",
26 | "border-color": border_color,
27 | "border-style": border_style,
28 | "border-width": border_width,
29 | color: border_color,
30 | "text-size": "75%",
31 | "font-family": "monospace",
32 | "padding-left": "0.2em",
33 | "padding-right": "0.2em",
34 | "border-radius": "0 3px 0 0",
35 | };
36 |
37 | // create and add the button to all the code blocks that contain >>>
38 | div.each(function (index) {
39 | var jthis = $(this);
40 | if (jthis.find(".gp").length > 0) {
41 |       var button = $('<span class="copybutton">&gt;&gt;&gt;</span>');
42 | button.css(button_styles);
43 | button.attr("title", hide_text);
44 | button.data("hidden", "false");
45 | jthis.prepend(button);
46 | }
47 | // tracebacks (.gt) contain bare text elements that need to be
48 | // wrapped in a span to work with .nextUntil() (see later)
49 | jthis
50 | .find("pre:has(.gt)")
51 | .contents()
52 | .filter(function () {
53 | return this.nodeType == 3 && this.data.trim().length > 0;
54 | })
55 | .wrap("");
56 | });
57 |
58 | // define the behavior of the button when it's clicked
59 | $(".copybutton").click(function (e) {
60 | e.preventDefault();
61 | var button = $(this);
62 | if (button.data("hidden") === "false") {
63 | // hide the code output
64 | button.parent().find(".go, .gp, .gt").hide();
65 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "hidden");
66 | button.css("text-decoration", "line-through");
67 | button.attr("title", show_text);
68 | button.data("hidden", "true");
69 | } else {
70 | // show the code output
71 | button.parent().find(".go, .gp, .gt").show();
72 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "visible");
73 | button.css("text-decoration", "none");
74 | button.attr("title", hide_text);
75 | button.data("hidden", "false");
76 | }
77 | });
78 | });
79 |
--------------------------------------------------------------------------------
/docs/source/_static/images/icon.svg:
--------------------------------------------------------------------------------
(SVG image content omitted)
--------------------------------------------------------------------------------
/docs/source/_static/images/logo-large.svg:
--------------------------------------------------------------------------------
(SVG image content omitted)
--------------------------------------------------------------------------------
/docs/source/_static/images/logo-small.svg:
--------------------------------------------------------------------------------
(SVG image content omitted)
--------------------------------------------------------------------------------
/docs/source/_static/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/utilities/2658ff8b341fc0f30e9d9e9042a9d6c0e4de49a2/docs/source/_static/images/logo.png
--------------------------------------------------------------------------------
/docs/source/_static/images/logo.svg:
--------------------------------------------------------------------------------
(SVG image content omitted)
--------------------------------------------------------------------------------
/docs/source/_templates/theme_variables.jinja:
--------------------------------------------------------------------------------
1 | {%- set external_urls = {
2 | 'github': 'https://github.com/Lightning-AI/utilities',
3 | 'github_issues': 'https://github.com/Lightning-AI/utilities/issues',
4 | 'contributing': 'https://github.com/Lightning-AI/lightning/blob/master/CONTRIBUTING.md',
5 | 'governance': 'https://github.com/Lightning-AI/lightning/blob/master/governance.md',
6 | 'docs': 'https://dev-toolbox.rtfd.io/en/latest',
7 | 'twitter': 'https://twitter.com/PyTorchLightnin',
8 | 'discuss': 'https://pytorch-lightning.slack.com',
9 | 'home': 'https://lightning-tools.rtfd.io/en/latest/',
10 | 'get_started': 'https://lightning-tools.readthedocs.io/en/latest/introduction_guide.html',
11 | 'features': 'https://lightning-tools.rtfd.io/en/latest/',
12 | 'blog': 'https://www.pytorchlightning.ai/blog',
13 | 'support': 'https://lightning-tools.rtfd.io/en/latest/',
14 | }
15 | -%}
16 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. LightningAI-DevToolbox documentation master file, created by
2 | sphinx-quickstart on Wed Mar 25 21:34:07 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Lightning-DevToolbox documentation
7 | ==================================
8 |
9 | .. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/Lightning.gif
10 | :alt: What is Lightning gif.
11 | :width: 80 %
12 |
13 | .. toctree::
14 | :maxdepth: 1
15 | :name: content
16 | :caption: Overview
17 |
18 |    Utilities readme <readme>
19 |
20 |
21 | Indices and tables
22 | ==================
23 |
24 | * :ref:`genindex`
25 | * :ref:`modindex`
26 | * :ref:`search`
27 |
--------------------------------------------------------------------------------
/docs/source/test-page.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | Testing page
4 | ============
5 |
6 | This page serves exclusively for testing purposes.
7 |
8 | Link to scikit-learn stable documentation: https://scikit-learn.org/stable/index.html
9 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools",
4 | "wheel",
5 | ]
6 |
7 | [tool.ruff]
8 | target-version = "py39"
9 |
10 | line-length = 120
11 | format.preview = true
12 | lint.select = [
13 | "D", # see: https://pypi.org/project/pydocstyle
14 | "E",
15 | "F", # see: https://pypi.org/project/pyflakes
16 | "I", #see: https://pypi.org/project/isort/
17 | "N", # see: https://pypi.org/project/pep8-naming
18 | "RUF100", # see: https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
19 | "S", # see: https://pypi.org/project/flake8-bandit
20 | "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up
21 | "W", # see: https://pypi.org/project/pycodestyle
22 | ]
23 | lint.extend-select = [
24 | "A", # see: https://pypi.org/project/flake8-builtins
25 | "ANN", # see: https://pypi.org/project/flake8-annotations
26 | "B", # see: https://pypi.org/project/flake8-bugbear
27 | "C4", # see: https://pypi.org/project/flake8-comprehensions
28 | "EXE", # see: https://pypi.org/project/flake8-executable
29 | "ISC", # see: https://pypi.org/project/flake8-implicit-str-concat
30 | "PIE", # see: https://pypi.org/project/flake8-pie
31 | "PLE", # see: https://pypi.org/project/pylint/
32 | "PT", # see: https://pypi.org/project/flake8-pytest-style
33 | "Q", # see: https://pypi.org/project/flake8-quotes
34 | "RET", # see: https://pypi.org/project/flake8-return
35 | "RUF", # Ruff-specific rules
36 | "SIM", # see: https://pypi.org/project/flake8-simplify
37 | "T10", # see: https://pypi.org/project/flake8-debugger
38 | "TID", # see: https://pypi.org/project/flake8-tidy-imports/
39 | "YTT", # see: https://pypi.org/project/flake8-2020
40 | ]
41 | lint.ignore = [
42 | "E731",
43 | ]
44 | lint.per-file-ignores."__about__.py" = [
45 | "D100",
46 | ]
47 | lint.per-file-ignores."__init__.py" = [
48 | "D100",
49 | ]
50 | lint.per-file-ignores."docs/source/conf.py" = [
51 | "A001",
52 | "ANN001",
53 | "ANN201",
54 | "D100",
55 | "D103",
56 | ]
57 | lint.per-file-ignores."setup.py" = [
58 | "ANN202",
59 | "D100",
60 | "SIM115",
61 | ]
62 | lint.per-file-ignores."src/**" = [
63 | "ANN101", # Missing type annotation for `self` in method
64 | "ANN102", # Missing type annotation for `cls` in classmethod
65 | "ANN401", # Dynamically typed expressions (typing.Any)
66 | "B905", # `zip()` without an explicit `strict=` parameter
67 | "D100", # Missing docstring in public module
68 | "D107", # Missing docstring in `__init__`
69 | ]
70 | lint.per-file-ignores."tests/**" = [
71 | "ANN001", # Missing type annotation for function argument
72 | "ANN101", # Missing type annotation for `self` in method
73 | "ANN201", # Missing return type annotation for public function
74 | "ANN202", # Missing return type annotation for private function
75 | "ANN204", # Missing return type annotation for special method
76 | "ANN401", # Dynamically typed expressions (typing.Any)
77 | "B028", # No explicit `stacklevel` keyword argument found
78 | "B905", # `zip()` without an explicit `strict=` parameter
79 | "D100", # Missing docstring in public module
80 | "D101", # Missing docstring in public class
81 | "D102", # Missing docstring in public method
82 | "D103", # Missing docstring in public function
83 | "D104", # Missing docstring in public package
84 | "D105", # Missing docstring in magic method
85 | "D107", # Missing docstring in `__init__`
86 | "S101", # Use of `assert` detected
87 | "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
88 | ]
89 | lint.mccabe.max-complexity = 10
90 | # Use Google-style docstrings.
91 | lint.pydocstyle.convention = "google"
92 |
93 | [tool.codespell]
94 | #skip = '*.py'
95 | quiet-level = 3
96 | # comma separated list of words; waiting for:
97 | # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603
98 | # also adding links until they are ignored by their nature:
99 | # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960
100 | #ignore-words-list = ""
101 |
102 | [tool.docformatter]
103 | recursive = true
104 | # this needs to be shorter as some docstrings are r"""...
105 | wrap-summaries = 119
106 | wrap-descriptions = 120
107 | blank = true
108 |
109 | [tool.check-manifest]
110 | ignore = [
111 | "*.yml",
112 | ".github",
113 | ".github/*",
114 | ]
115 |
116 | [tool.pytest.ini_options]
117 | norecursedirs = [
118 | ".git",
119 | ".github",
120 | "dist",
121 | "build",
122 | "docs",
123 | ]
124 | addopts = [
125 | "--strict-markers",
126 | "--doctest-modules",
127 | "--durations=25",
128 | "--color=yes",
129 | "--disable-pytest-warnings",
130 | ]
131 | markers = [
132 | "online: run tests that require internet connection",
133 | ]
134 | filterwarnings = [
135 | "error::FutureWarning",
136 | ] # todo: "error::DeprecationWarning"
137 | xfail_strict = true
138 |
139 | [tool.coverage.report]
140 | exclude_lines = [
141 | "pragma: no cover",
142 | "pass",
143 | ]
144 | #[tool.coverage.run]
145 | #parallel = true
146 | #concurrency = ['thread']
147 | #relative_files = true
148 |
149 | [tool.mypy]
150 | files = [
151 | "src/lightning_utilities",
152 | ]
153 | disallow_untyped_defs = true
154 | ignore_missing_imports = true
155 |
--------------------------------------------------------------------------------
/requirements/_docs.txt:
--------------------------------------------------------------------------------
1 | sphinx >=5.0,<6.0
2 | myst-parser >=1.0.0, <2.0.0
3 | nbsphinx >=0.8.5
4 | ipython[notebook]
5 | pandoc >=1.0
6 | docutils >=0.16
7 | # https://github.com/jupyterlab/jupyterlab_pygments/issues/5
8 | pygments >=2.4.1
9 | #sphinxcontrib-fulltoc >=1.0
10 | #sphinxcontrib-mockautodoc
11 |
12 | pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip
13 | sphinx-autodoc-typehints >=1.0
14 | sphinx-paramlinks >=0.5.1
15 | sphinx-togglebutton >=0.2
16 | sphinx-copybutton >=0.3
17 |
--------------------------------------------------------------------------------
/requirements/_tests.txt:
--------------------------------------------------------------------------------
1 | coverage ==7.2.7; python_version < '3.8'
2 | coverage ==7.8.*; python_version >= '3.8'
3 | pytest ==7.4.3; python_version < '3.8'
4 | pytest ==8.3.*; python_version >= '3.8'
5 | pytest-cov ==6.1.1
6 | pytest-timeout ==2.4.0
7 |
--------------------------------------------------------------------------------
/requirements/cli.txt:
--------------------------------------------------------------------------------
1 | jsonargparse[signatures] >=4.38.0
2 |
--------------------------------------------------------------------------------
/requirements/core.txt:
--------------------------------------------------------------------------------
1 | importlib-metadata >=4.0.0; python_version < '3.8'
2 | packaging >=17.1
3 | setuptools
4 | typing_extensions
5 |
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | requests >=2.0.0
2 |
--------------------------------------------------------------------------------
/requirements/gha-package.txt:
--------------------------------------------------------------------------------
1 | twine ==6.1.*
2 | setuptools ==80.9.*
3 | wheel ==0.45.*
4 | build ==1.2.*
5 | importlib_metadata ==8.7.*
6 | packaging ==25.0
7 |
--------------------------------------------------------------------------------
/requirements/gha-schema.txt:
--------------------------------------------------------------------------------
1 | check-jsonschema ==0.33.*
2 |
--------------------------------------------------------------------------------
/requirements/typing.txt:
--------------------------------------------------------------------------------
1 | mypy>=1.0.0
2 |
3 | types-setuptools
4 |
--------------------------------------------------------------------------------
/scripts/adjust-torch-versions.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 | #
4 | """Adjusting version across PTorch ecosystem."""
5 |
6 | import logging
7 | import os
8 | import re
9 | import sys
10 | from typing import Optional
11 |
12 |
13 | def _determine_torchaudio(torch_version: str) -> str:
14 | """Determine the torchaudio version based on the torch version.
15 |
16 | >>> _determine_torchaudio("1.9.0")
17 | '0.9.0'
18 | >>> _determine_torchaudio("2.4.1")
19 | '2.4.1'
20 | >>> _determine_torchaudio("1.8.2")
21 | '0.9.1'
22 |
23 | """
24 | _version_exceptions = {
25 | "2.0.1": "2.0.2",
26 | "2.0.0": "2.0.1",
27 | "1.8.2": "0.9.1",
28 | }
29 | if torch_version in _version_exceptions:
30 | return _version_exceptions[torch_version]
31 | ver_major, ver_minor, ver_bugfix = map(int, torch_version.split("."))
32 | ta_ver_array = [ver_major, ver_minor, ver_bugfix]
33 | if ver_major == 1:
34 | ta_ver_array[0] = 0
35 | ta_ver_array[2] = ver_bugfix
36 | return ".".join(map(str, ta_ver_array))
37 |
38 |
39 | def _determine_torchtext(torch_version: str) -> str:
40 | """Determine the torchtext version based on the torch version.
41 |
42 | >>> _determine_torchtext("1.9.0")
43 | '0.10.0'
44 | >>> _determine_torchtext("2.4.1")
45 | '0.18.0'
46 | >>> _determine_torchtext("1.8.2")
47 | '0.9.1'
48 |
49 | """
50 | _version_exceptions = {
51 | "2.0.1": "0.15.2",
52 | "2.0.0": "0.15.1",
53 | "1.8.2": "0.9.1",
54 | }
55 | if torch_version in _version_exceptions:
56 | return _version_exceptions[torch_version]
57 | ver_major, ver_minor, ver_bugfix = map(int, torch_version.split("."))
58 | tt_ver_array = [0, 0, 0]
59 | if ver_major == 1:
60 | tt_ver_array[1] = ver_minor + 1
61 | tt_ver_array[2] = ver_bugfix
62 | elif ver_major == 2:
63 | if ver_minor >= 3:
64 | tt_ver_array[1] = 18
65 | else:
66 | tt_ver_array[1] = ver_minor + 15
67 | tt_ver_array[2] = ver_bugfix
68 | else:
69 | raise ValueError(f"Invalid torch version: {torch_version}")
70 | return ".".join(map(str, tt_ver_array))
71 |
72 |
73 | def _determine_torchvision(torch_version: str) -> str:
74 | """Determine the torchvision version based on the torch version.
75 |
76 | >>> _determine_torchvision("1.9.0")
77 | '0.10.0'
78 | >>> _determine_torchvision("2.4.1")
79 | '0.19.1'
80 | >>> _determine_torchvision("2.0.1")
81 | '0.15.2'
82 |
83 | """
84 | _version_exceptions = {
85 | "2.0.1": "0.15.2",
86 | "2.0.0": "0.15.1",
87 | "1.10.2": "0.11.3",
88 | "1.10.1": "0.11.2",
89 | "1.10.0": "0.11.1",
90 | "1.8.2": "0.9.1",
91 | }
92 | if torch_version in _version_exceptions:
93 | return _version_exceptions[torch_version]
94 | ver_major, ver_minor, ver_bugfix = map(int, torch_version.split("."))
95 | tv_ver_array = [0, 0, 0]
96 | if ver_major == 1:
97 | tv_ver_array[1] = ver_minor + 1
98 | elif ver_major == 2:
99 | tv_ver_array[1] = ver_minor + 15
100 | else:
101 | raise ValueError(f"Invalid torch version: {torch_version}")
102 | tv_ver_array[2] = ver_bugfix
103 | return ".".join(map(str, tv_ver_array))
104 |
105 |
106 | def find_latest(ver: str) -> dict[str, str]:
107 | """Find the latest version.
108 |
109 | >>> from pprint import pprint
110 | >>> pprint(find_latest("2.4.1"))
111 | {'torch': '2.4.1',
112 | 'torchaudio': '2.4.1',
113 | 'torchtext': '0.18.0',
114 | 'torchvision': '0.19.1'}
115 | >>> pprint(find_latest("2.1"))
116 | {'torch': '2.1.0',
117 | 'torchaudio': '2.1.0',
118 | 'torchtext': '0.16.0',
119 | 'torchvision': '0.16.0'}
120 |
121 | """
122 | # drop all except semantic version
123 | ver = re.search(r"([\.\d]+)", ver).groups()[0]
124 |     # in case there is a remaining dot at the end - e.g. "1.9.0.dev20210504"
125 | ver = ver[:-1] if ver[-1] == "." else ver
126 | if not re.match(r"\d+\.\d+\.\d+", ver):
127 | ver += ".0" # add missing bugfix
128 | logging.debug(f"finding ecosystem versions for: {ver}")
129 |
130 | # find first match
131 | return {
132 | "torch": ver,
133 | "torchvision": _determine_torchvision(ver),
134 | "torchtext": _determine_torchtext(ver),
135 | "torchaudio": _determine_torchaudio(ver),
136 | }
137 |
138 |
139 | def adjust(requires: list[str], pytorch_version: Optional[str] = None) -> list[str]:
140 | """Adjust the versions to be paired within pytorch ecosystem.
141 |
142 | >>> from pprint import pprint
143 | >>> pprint(adjust(["torch>=1.9.0", "torchvision>=0.10.0", "torchtext>=0.10.0", "torchaudio>=0.9.0"], "2.1.0"))
144 | ['torch==2.1.0',
145 | 'torchvision==0.16.0',
146 | 'torchtext==0.16.0',
147 | 'torchaudio==2.1.0']
148 |
149 | """
150 | if not pytorch_version:
151 | import torch
152 |
153 | pytorch_version = torch.__version__
154 | if not pytorch_version:
155 | raise ValueError(f"invalid torch: {pytorch_version}")
156 |
157 | requires_ = []
158 | options = find_latest(pytorch_version)
159 | logging.debug(f"determined ecosystem alignment: {options}")
160 | for req in requires:
161 | req_split = req.strip().split("#", maxsplit=1)
162 |         # anything before the first '#' is the requirement
163 |         req = req_split[0].strip()
164 |         # anything after '#' on the line is a comment
165 | comment = "" if len(req_split) < 2 else " #" + req_split[1]
166 | if not req:
167 | # if only comment make it short
168 | requires_.append(comment.strip())
169 | continue
170 | for lib, version in options.items():
171 | replace = f"{lib}=={version}" if version else ""
172 | req = re.sub(rf"\b{lib}(?![-_\w]).*", replace, req)
173 | requires_.append(req + comment.rstrip())
174 |
175 | return requires_
176 |
177 |
178 | def _offset_print(reqs: list[str], offset: str = "\t|\t") -> str:
179 | """Adding offset to each line for the printing requirements."""
180 | reqs = [offset + r for r in reqs]
181 | return os.linesep.join(reqs)
182 |
183 |
184 | def main(requirements_path: str, torch_version: Optional[str] = None) -> None:
185 | """The main entry point with mapping to the CLI for positional arguments only."""
186 | # rU - universal line ending - https://stackoverflow.com/a/2717154/4521646
187 | with open(requirements_path, encoding="utf8") as fopen:
188 | requirements = fopen.readlines()
189 | requirements = adjust(requirements, torch_version)
190 | logging.info(
191 | f"requirements_path='{requirements_path}' with arg torch_version='{torch_version}' >>\n"
192 | f"{_offset_print(requirements)}"
193 | )
194 | with open(requirements_path, "w", encoding="utf8") as fopen:
195 | fopen.writelines([r + os.linesep for r in requirements])
196 |
197 |
198 | if __name__ == "__main__":
199 | logging.basicConfig(level=logging.INFO)
200 | try:
201 | from jsonargparse import auto_cli, set_parsing_settings
202 |
203 | set_parsing_settings(parse_optionals_as_positionals=True)
204 | auto_cli(main)
205 | except (ModuleNotFoundError, ImportError):
206 | main(*sys.argv[1:])
207 |
--------------------------------------------------------------------------------
/scripts/inject-selector-script.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 | #
4 | """Simple script to inject a custom JS script into all HTML pages in given folder.
5 |
6 | Sample usage:
7 | $ python scripts/inject-selector-script.py "/path/to/folder" torchmetrics
8 |
9 | """
10 |
11 | import logging
12 | import os
13 | import sys
14 |
15 |
16 | def inject_selector_script_into_html_file(file_path: str, script_url: str) -> None:
17 | """Inject a custom JS script into the given HTML file."""
18 | with open(file_path) as fopen:
19 | html_content = fopen.read()
20 |     html_content = html_content.replace(
21 |         "</head>",
22 |         f'<script src="{script_url}"></script>{os.linesep}</head>',
23 |     )
24 | with open(file_path, "w") as fopen:
25 | fopen.write(html_content)
26 |
27 |
28 | def main(folder: str, selector_name: str) -> None:
29 | """Inject a custom JS script into all HTML files in the given folder."""
30 | # Sample: https://lightning.ai/docs/torchmetrics/version-selector.js
31 | script_url = f"https://lightning.ai/docs/{selector_name}/version-selector.js"
32 | html_files = [
33 | os.path.join(root, file) for root, _, files in os.walk(folder) for file in files if file.endswith(".html")
34 | ]
35 | for file_path in html_files:
36 | inject_selector_script_into_html_file(file_path, script_url)
37 |
38 |
39 | if __name__ == "__main__":
40 | logging.basicConfig(level=logging.INFO)
41 | try:
42 | from jsonargparse import auto_cli, set_parsing_settings
43 |
44 | set_parsing_settings(parse_optionals_as_positionals=True)
45 | auto_cli(main)
46 | except (ModuleNotFoundError, ImportError):
47 | main(*sys.argv[1:])
48 |
--------------------------------------------------------------------------------
/scripts/run_standalone_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright The Lightning AI team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # THIS FILE ASSUMES IT IS RUN INSIDE THE tests DIRECTORY.
17 |
18 | # Batch size for testing: Determines how many standalone test invocations run in parallel
19 | # It can be set through the env variable NUM_PARALLEL_TESTS and defaults to 5 if not set
20 | test_batch_size="${NUM_PARALLEL_TESTS:-5}"
21 |
22 | # Source directory for coverage runs can be set with COVERAGE_SOURCE.
23 | codecov_source="${COVERAGE_SOURCE}"
24 |
25 | # The test directory is passed as the first argument to the script
26 | test_dir=$1 # parse the first argument
27 |
28 | # There is also a timeout for the tests.
29 | # It can be set through the env variable TEST_TIMEOUT and defaults to 1200 seconds.
30 | test_timeout="${TEST_TIMEOUT:-1200}"
31 |
32 | # Temporary file to store the collected tests
33 | COLLECTED_TESTS_FILE="collected_tests.txt"
34 |
35 | ls -lh . # show the contents of the directory
36 |
37 | # If codecov_source is set, prepend the coverage command
38 | if [ -n "$codecov_source" ]; then
39 | cli_coverage="-m coverage run --source ${codecov_source} --append"
40 | else # If not, just keep it empty
41 | cli_coverage=""
42 | fi
43 | # Append the common pytest arguments
44 | cli_pytest="-m pytest --no-header -v -s --color=yes --timeout=${test_timeout}"
45 |
46 | # Python arguments for running the tests and optional coverage
47 | printf "\e[35mUsing defaults: ${cli_coverage} ${cli_pytest}\e[0m\n"
48 |
49 | # Get the list of parametrizations. We need to call them separately; the last two lines are removed.
50 | # note: if there's a syntax error, this will fail with some garbled output
51 | python -um pytest ${test_dir} -q --collect-only --pythonwarnings ignore > $COLLECTED_TESTS_FILE 2>&1
52 | # Early terminate if collection failed (e.g. syntax error)
53 | if [[ $? != 0 ]]; then
54 | cat $COLLECTED_TESTS_FILE
55 | printf "ERROR: test collection failed!\n"
56 | exit 1
57 | fi
58 |
59 | # Initialize empty array
60 | tests=()
61 |
62 | # Read from file line by line
63 | while IFS= read -r line; do
64 | # Only keep lines containing "test_"
65 | if [[ $line == *"test_"* ]]; then
66 | # Extract part after test_dir/
67 | pruned_line="${line#*${test_dir}/}"
68 | tests+=("${test_dir}/$pruned_line")
69 | fi
70 | done < $COLLECTED_TESTS_FILE
71 |
72 | # Count tests
73 | test_count=${#tests[@]}
74 |
75 | # Display results
76 | printf "\e[34m================================================================================\e[0m\n"
77 | printf "\e[34mCOLLECTED $test_count TESTS:\e[0m\n"
78 | printf "\e[34m~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\e[0m\n"
79 | printf "\e[34m%s\e[0m\n" "${tests[@]}"
80 | printf "\e[34m================================================================================\e[0m\n"
81 |
82 | # warn if only a single test was collected; fail if none were found
83 | if [[ $test_count -eq 1 ]]; then
84 | printf "\e[33mWARNING: only one test found!\e[0m\n"
85 | elif [ $test_count -eq 0 ]; then
86 | printf "\e[31mERROR: no tests found!\e[0m\n"
87 | exit 1
88 | fi
89 |
90 | if [ -n "$codecov_source" ]; then
91 | coverage combine
92 | fi
93 |
94 | status=0 # aggregated script status
95 | report=() # final report
96 | pids=() # array of PID for running tests
97 | test_ids=() # array of indexes of running tests
98 | failed_tests=() # array of failed tests
99 | printf "Running $test_count tests in batches of $test_batch_size:\n"
100 | for i in "${!tests[@]}"; do
101 | test=${tests[$i]}
102 |
103 | cli_test="python "
104 | if [ -n "$codecov_source" ]; then
105 | # append cli_coverage to the test command
106 | cli_test="${cli_test} ${cli_coverage} --data-file=run-${i}.coverage"
107 | fi
108 | # add the pytest cli to the test command
109 | cli_test="${cli_test} ${cli_pytest}"
110 |
111 | printf "\e[95m* Running test $((i+1))/$test_count: $cli_test $test\e[0m\n"
112 |
113 | # execute the test in the background
114 | # redirect to a log file that buffers test output. since the tests will run in the background,
115 | # we cannot let them output to std{out,err} because the outputs would be garbled together
116 | ${cli_test} "$test" &> "parallel_test_output-$i.txt" &
117 | test_ids+=($i) # save the test's id in an array with running tests
118 | pids+=($!) # save the PID in an array with running tests
119 |
120 | # if we reached the batch size, wait for all tests to finish
121 | if (( (($i + 1) % $test_batch_size == 0) || $i == $test_count-1 )); then
122 | printf "Waiting for batch to finish: $(IFS=' '; echo "${pids[@]}")\n"
123 | # wait for running tests
124 | for j in "${!test_ids[@]}"; do
125 | i=${test_ids[$j]} # restore the global test's id
126 | pid=${pids[$j]} # restore the particular PID
127 | test=${tests[$i]} # restore the test name
128 | printf "\e[33m? Waiting for $test @ parallel_test_output-$i.txt (PID: $pid)\e[0m\n"
129 | wait -n $pid
130 | # get the exit status of the test
131 | test_status=$?
132 | # add row to the final report
133 | report+=("Ran $test >> exit:$test_status")
134 | if [[ $test_status != 0 ]]; then
135 | # add the test to the failed tests array
136 | failed_tests+=($i)
137 | # Process exited with a non-zero exit status
138 | status=$test_status
139 | fi
140 | done
141 | printf "Starting over with a new batch...\n"
142 | test_ids=() # reset the test's id array
143 | pids=() # reset the PID array
144 | fi
145 | done
146 |
147 | # print test report with exit code for each test
148 | printf "\e[35m================================================================================\e[0m\n"
149 | for line in "${report[@]}"; do
150 | if [[ "$line" == *"exit:0"* ]]; then
151 | printf "\e[32m%s\e[0m\n" "$line" # Green for lines containing exit:0
152 | else
153 | printf "\e[31m%s\e[0m\n" "$line" # Red for all other lines
154 | fi
155 | done
156 | printf "\e[35m================================================================================\e[0m\n"
157 |
158 | # print the output of failed tests from their dumped log files
159 | if [[ ${#failed_tests[@]} -gt 0 ]]; then
160 | printf "\e[34mFAILED TESTS:\e[0m\n"
161 | for i in "${failed_tests[@]}"; do
162 | printf "\e[34m================================================================================\e[0m\n"
163 | printf "\e[34m=== ${tests[$i]} ===\e[0m\n"
164 | printf "\e[34m~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\e[0m\n\n"
165 | # show the output of the failed test
166 | cat "parallel_test_output-$i.txt"
167 | printf "\e[34m~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\e[0m\n"
168 | printf "\e[34m================================================================================\e[0m\n"
169 | printf '\n\n\n'
170 | done
171 | else
172 | printf "\e[32mAll tests passed!\e[0m\n"
173 | fi
174 |
175 | # exit with the worst test result
176 | exit $status
177 |
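The batching logic above reduces to a simple pattern: launch up to NUM_PARALLEL_TESTS commands in the background, wait for the whole batch, and remember any non-zero exit status. A minimal Python sketch of the same pattern (illustrative only, not part of the repository):

import subprocess

def run_in_batches(commands: list, batch_size: int = 5) -> int:
    """Run commands in fixed-size parallel batches; return the last non-zero exit code, else 0."""
    status = 0
    for start in range(0, len(commands), batch_size):
        procs = [subprocess.Popen(cmd) for cmd in commands[start : start + batch_size]]
        for proc in procs:
            proc.wait()
            if proc.returncode != 0:
                status = proc.returncode
    return status

# e.g. run_in_batches([["python", "-m", "pytest", t] for t in tests], batch_size=5)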
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import glob
3 | import os
4 | from importlib.util import module_from_spec, spec_from_file_location
5 |
6 | from pkg_resources import parse_requirements
7 | from setuptools import find_packages, setup
8 |
9 | _PATH_ROOT = os.path.realpath(os.path.dirname(__file__))
10 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
11 | _PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements")
12 |
13 |
14 | def _load_py_module(fname: str, pkg: str = "lightning_utilities"):
15 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname))
16 | py = module_from_spec(spec)
17 | spec.loader.exec_module(py)
18 | return py
19 |
20 |
21 | about = _load_py_module("__about__.py")
22 |
23 | # load basic requirements
24 | with open(os.path.join(_PATH_REQUIRE, "core.txt")) as fp:
25 | requirements = list(map(str, parse_requirements(fp.readlines())))
26 |
27 |
28 | # build pip extras by automatically loading the requirement files
29 | def _requirement_extras(path_req: str = _PATH_REQUIRE) -> dict:
30 | extras = {}
31 | for fpath in glob.glob(os.path.join(path_req, "*.txt")):
32 | fname = os.path.basename(fpath)
33 | if fname.startswith(("_", "gha-")):
34 | continue
35 | if fname in ("core.txt",):
36 | continue
37 | name, _ = os.path.splitext(fname)
38 | with open(fpath) as fp:
39 | reqs = parse_requirements(fp.readlines())
40 | extras[name] = list(map(str, reqs))
41 | return extras
42 |
43 |
44 | # loading readme as description
45 | with open(os.path.join(_PATH_ROOT, "README.md")) as fp:
46 | readme = fp.read()
47 |
48 | setup(
49 | name="lightning-utilities",
50 | version=about.__version__,
51 | description=about.__docs__,
52 | author=about.__author__,
53 | author_email=about.__author_email__,
54 | url=about.__homepage__,
55 | download_url="https://github.com/Lightning-AI/utilities",
56 | license=about.__license__,  # an SPDX license identifier, e.g. "Apache-2.0"
57 | license_files=["LICENSE"], # Path to your license file
58 | packages=find_packages(where="src"),
59 | package_dir={"": "src"},
60 | long_description=readme,
61 | long_description_content_type="text/markdown",
62 | include_package_data=True,
63 | zip_safe=False,
64 | keywords=["Utilities", "DevOps", "CI/CD"],
65 | python_requires=">=3.9",
66 | setup_requires=[],
67 | install_requires=requirements,
68 | extras_require=_requirement_extras(),
69 | project_urls={
70 | "Bug Tracker": "https://github.com/Lightning-AI/utilities/issues",
71 | "Documentation": "https://dev-toolbox.rtfd.io/en/latest/", # TODO: Update domain
72 | "Source Code": "https://github.com/Lightning-AI/utilities",
73 | },
74 | classifiers=[
75 | "Environment :: Console",
76 | "Natural Language :: English",
77 | # How mature is this project? Common values are
78 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable
79 | "Development Status :: 3 - Alpha",
80 | # Indicate who your project is intended for
81 | "Intended Audience :: Developers",
82 | # Pick your license as you wish
83 | # 'License :: OSI Approved :: BSD License',
84 | "Operating System :: OS Independent",
85 | "Programming Language :: Python :: 3",
86 | "Programming Language :: Python :: 3.8",
87 | "Programming Language :: Python :: 3.9",
88 | "Programming Language :: Python :: 3.10",
89 | "Programming Language :: Python :: 3.11",
90 | "Programming Language :: Python :: 3.12",
91 | "Programming Language :: Python :: 3.13",
92 | ],
93 | )
94 |
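Because `_requirement_extras()` maps every non-filtered `requirements/*.txt` file to a pip extra named after the file, users can opt into extra dependency groups at install time, e.g. `pip install "lightning-utilities[cli]"` (extra names depend on the files actually present). A hedged sketch of verifying this from an installed environment:

# Sketch: list the extras advertised by the installed distribution
# (assumes lightning-utilities is installed in the current environment).
from importlib.metadata import metadata

md = metadata("lightning-utilities")
print(md.get_all("Provides-Extra"))  # e.g. ['cli', 'docs', 'typing']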
--------------------------------------------------------------------------------
/src/lightning_utilities/__about__.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | __version__ = "0.15.0dev"
4 | __author__ = "Lightning AI et al."
5 | __author_email__ = "pytorch@lightning.ai"
6 | __license__ = "Apache-2.0"
7 | __copyright__ = f"Copyright (c) 2022-{time.strftime('%Y')}, {__author__}."
8 | __homepage__ = "https://github.com/Lightning-AI/utilities"
9 | __docs__ = "Lightning toolbox shared across the Lightning ecosystem."
10 | __long_doc__ = """
11 | This package allows for sharing GitHub workflows, CI/CD assistance actions, and Python utilities across the
12 | Lightning ecosystem projects.
13 | """
14 |
15 | __all__ = [
16 | "__author__",
17 | "__author_email__",
18 | "__copyright__",
19 | "__docs__",
20 | "__homepage__",
21 | "__license__",
22 | "__version__",
23 | ]
24 |
--------------------------------------------------------------------------------
/src/lightning_utilities/__init__.py:
--------------------------------------------------------------------------------
1 | """Root package info."""
2 |
3 | import os
4 |
5 | from lightning_utilities.__about__ import * # noqa: F403
6 | from lightning_utilities.core.apply_func import apply_to_collection
7 | from lightning_utilities.core.enums import StrEnum
8 | from lightning_utilities.core.imports import compare_version, module_available
9 | from lightning_utilities.core.overrides import is_overridden
10 | from lightning_utilities.core.rank_zero import WarningCache
11 |
12 | _PACKAGE_ROOT = os.path.dirname(__file__)
13 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
14 |
15 |
16 | __all__ = [
17 | "StrEnum",
18 | "WarningCache",
19 | "apply_to_collection",
20 | "compare_version",
21 | "is_overridden",
22 | "module_available",
23 | ]
24 |
--------------------------------------------------------------------------------
/src/lightning_utilities/cli/__init__.py:
--------------------------------------------------------------------------------
1 | """CLI root."""
2 |
--------------------------------------------------------------------------------
/src/lightning_utilities/cli/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 |
6 | import lightning_utilities
7 | from lightning_utilities.cli.dependencies import (
8 | prune_packages_in_requirements,
9 | replace_oldest_version,
10 | replace_package_in_requirements,
11 | )
12 |
13 |
14 | def _get_version() -> None:
15 | """Prints the version of the lightning_utilities package."""
16 | print(lightning_utilities.__version__)
17 |
18 |
19 | def main() -> None:
20 | """CLI entry point."""
21 | from jsonargparse import auto_cli, set_parsing_settings
22 |
23 | set_parsing_settings(parse_optionals_as_positionals=True)
24 | auto_cli({
25 | "requirements": {
26 | "_help": "Manage requirements files.",
27 | "prune-pkgs": prune_packages_in_requirements,
28 | "set-oldest": replace_oldest_version,
29 | "replace-pkg": replace_package_in_requirements,
30 | },
31 | "version": _get_version,
32 | })
33 |
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
--------------------------------------------------------------------------------
/src/lightning_utilities/cli/dependencies.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 | import glob
6 | import os.path
7 | import re
8 | from collections.abc import Sequence
9 | from pprint import pprint
10 | from typing import Union
11 |
12 | REQUIREMENT_ROOT = "requirements.txt"
13 | REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
14 | REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
15 | if os.path.isfile(REQUIREMENT_ROOT):
16 | REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]
17 |
18 |
19 | def prune_packages_in_requirements(
20 | packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
21 | ) -> None:
22 | """Remove some packages from given requirement files."""
23 | if isinstance(packages, str):
24 | packages = [packages]
25 | if isinstance(req_files, str):
26 | req_files = [req_files]
27 | for req in req_files:
28 | _prune_packages(req, packages)
29 |
30 |
31 | def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
32 | """Remove some packages from given requirement files."""
33 | with open(req_file) as fp:
34 | lines = fp.readlines()
35 |
36 | if isinstance(packages, str):
37 | packages = [packages]
38 | for pkg in packages:
39 | lines = [ln for ln in lines if not ln.startswith(pkg)]
40 | pprint(lines)
41 |
42 | with open(req_file, "w") as fp:
43 | fp.writelines(lines)
44 |
45 |
46 | def _replace_min(fname: str) -> None:
47 | with open(fname) as fopen:
48 | req = fopen.read().replace(">=", "==")
49 | with open(fname, "w") as fw:
50 | fw.write(req)
51 |
52 |
53 | def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
54 | """Replace the min package version by fixed one."""
55 | if isinstance(req_files, str):
56 | req_files = [req_files]
57 | for fname in req_files:
58 | _replace_min(fname)
59 |
60 |
61 | def _replace_package_name(requirements: list[str], old_package: str, new_package: str) -> list[str]:
62 | """Replace one package by another with same version in given requirement file.
63 |
64 | >>> _replace_package_name(["torch>=1.0 # comment", "torchvision>=0.2", "torchtext <0.3"], "torch", "pytorch")
65 | ['pytorch>=1.0 # comment', 'torchvision>=0.2', 'torchtext <0.3']
66 |
67 | """
68 | for i, req in enumerate(requirements):
69 | requirements[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>#]|$)", new_package, req)
70 | return requirements
71 |
72 |
73 | def replace_package_in_requirements(
74 | old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
75 | ) -> None:
76 | """Replace one package by another with same version in given requirement files."""
77 | if isinstance(req_files, str):
78 | req_files = [req_files]
79 | for fname in req_files:
80 | with open(fname) as fopen:
81 | reqs = fopen.readlines()
82 | reqs = _replace_package_name(reqs, old_package, new_package)
83 | with open(fname, "w") as fw:
84 | fw.writelines(reqs)
85 |
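The three public helpers can also be called directly from Python rather than through the CLI; a short usage sketch (the file paths are examples):

from lightning_utilities.cli.dependencies import (
    prune_packages_in_requirements,
    replace_oldest_version,
    replace_package_in_requirements,
)

# drop `sphinx` from a docs requirements file (example path)
prune_packages_in_requirements("sphinx", req_files="requirements/docs.txt")
# pin every `>=` lower bound to `==` so the oldest supported versions get tested
replace_oldest_version(req_files="requirements/docs.txt")
# rename a dependency while keeping its version specifier
replace_package_in_requirements("torch", "pytorch", req_files="requirements/core.txt")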
--------------------------------------------------------------------------------
/src/lightning_utilities/core/__init__.py:
--------------------------------------------------------------------------------
1 | """Core utilities."""
2 |
3 | from lightning_utilities.core.apply_func import apply_to_collection
4 | from lightning_utilities.core.enums import StrEnum
5 | from lightning_utilities.core.imports import compare_version, module_available
6 | from lightning_utilities.core.overrides import is_overridden
7 | from lightning_utilities.core.rank_zero import WarningCache
8 |
9 | __all__ = [
10 | "StrEnum",
11 | "WarningCache",
12 | "apply_to_collection",
13 | "compare_version",
14 | "is_overridden",
15 | "module_available",
16 | ]
17 |
--------------------------------------------------------------------------------
/src/lightning_utilities/core/enums.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 | import warnings
6 | from enum import Enum
7 | from typing import Optional
8 |
9 | from typing_extensions import Literal
10 |
11 |
12 | class StrEnum(str, Enum):
13 | """Type of any enumerator with allowed comparison to string invariant to cases.
14 |
15 | >>> class MySE(StrEnum):
16 | ... t1 = "T-1"
17 | ... t2 = "T-2"
18 | >>> MySE("T-1") == MySE.t1
19 | True
20 | >>> MySE.from_str("t-2", source="value") == MySE.t2
21 | True
22 | >>> MySE.from_str("t-2", source="value")
23 | <MySE.t2: 'T-2'>
24 | >>> MySE.from_str("t-3", source="any")
25 | Traceback (most recent call last):
26 | ...
27 | ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
28 |
29 | """
30 |
31 | @classmethod
32 | def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum":
33 | """Create ``StrEnum`` from a string matching the key or value.
34 |
35 | Args:
36 | value: matching string
37 | source: compare with:
38 |
39 | - ``"key"``: validates only from the enum keys, typical alphanumeric with "_"
40 | - ``"value"``: validates only from the values, could be any string
41 | - ``"any"``: validates with any key or value, but key has priority
42 |
43 | Raises:
44 | ValueError:
45 | if requested string does not match any option based on selected source.
46 |
47 | """
48 | if source in ("key", "any"):
49 | for enum_key in cls.__members__:
50 | if enum_key.lower() == value.lower():
51 | return cls[enum_key]
52 | if source in ("value", "any"):
53 | for enum_key, enum_val in cls.__members__.items():
54 | if enum_val == value:
55 | return cls[enum_key]
56 | raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.")
57 |
58 | @classmethod
59 | def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
60 | """Try to create emun and if it does not match any, return `None`."""
61 | try:
62 | return cls.from_str(value, source)
63 | except ValueError:
64 | warnings.warn( # noqa: B028
65 | UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.")
66 | )
67 | return None
68 |
69 | @classmethod
70 | def _allowed_matches(cls, source: str) -> list[str]:
71 | keys, vals = [], []
72 | for enum_key, enum_val in cls.__members__.items():
73 | keys.append(enum_key)
74 | vals.append(enum_val.value)
75 | if source == "key":
76 | return keys
77 | if source == "value":
78 | return vals
79 | return keys + vals
80 |
81 | def __eq__(self, other: object) -> bool:
82 | """Compare two instances."""
83 | if isinstance(other, Enum):
84 | other = other.value
85 | return self.value.lower() == str(other).lower()
86 |
87 | def __hash__(self) -> int:
88 | """Return unique hash."""
89 | # re-enable hashing, so the enum can be used as a dict key or in a set
90 | # example: set(LightningEnum)
91 | return hash(self.value.lower())
92 |
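A short usage sketch of `StrEnum` (the `Accelerator` enum below is hypothetical):

from lightning_utilities.core.enums import StrEnum

class Accelerator(StrEnum):  # hypothetical enum for illustration
    CPU = "cpu"
    GPU = "gpu"

assert Accelerator.GPU == "GPU"  # equality is case-insensitive
assert Accelerator.from_str("gpu", source="key") is Accelerator.GPU
assert Accelerator.try_from_str("tpu") is None  # warns instead of raising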
--------------------------------------------------------------------------------
/src/lightning_utilities/core/inheritance.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 | from collections.abc import Iterator
6 |
7 |
8 | def get_all_subclasses_iterator(cls: type) -> Iterator[type]:
9 | """Iterate over all subclasses."""
10 |
11 | def recurse(cl: type) -> Iterator[type]:
12 | for subclass in cl.__subclasses__():
13 | yield subclass
14 | yield from recurse(subclass)
15 |
16 | yield from recurse(cls)
17 |
18 |
19 | def get_all_subclasses(cls: type) -> set[type]:
20 | """List all subclasses of a class."""
21 | return set(get_all_subclasses_iterator(cls))
22 |
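Usage sketch (the classes below are hypothetical):

from lightning_utilities.core.inheritance import get_all_subclasses

class Base: ...
class Child(Base): ...
class GrandChild(Child): ...

assert get_all_subclasses(Base) == {Child, GrandChild}
assert get_all_subclasses(GrandChild) == set()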
--------------------------------------------------------------------------------
/src/lightning_utilities/core/overrides.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 | from functools import partial
6 | from unittest.mock import Mock
7 |
8 |
9 | def is_overridden(method_name: str, instance: object, parent: type[object]) -> bool:
10 | """Check if a method of a given object was overwritten."""
11 | instance_attr = getattr(instance, method_name, None)
12 | if instance_attr is None:
13 | return False
14 | # `functools.wraps()` and `@contextmanager` support
15 | if hasattr(instance_attr, "__wrapped__"):
16 | instance_attr = instance_attr.__wrapped__
17 | # `Mock(wraps=...)` support
18 | if isinstance(instance_attr, Mock):
19 | # access the wrapped function
20 | instance_attr = instance_attr._mock_wraps
21 | # `partial` support
22 | elif isinstance(instance_attr, partial):
23 | instance_attr = instance_attr.func
24 | if instance_attr is None:
25 | return False
26 |
27 | parent_attr = getattr(parent, method_name, None)
28 | if parent_attr is None:
29 | raise ValueError("The parent should define the method")
30 | # `@contextmanager` support
31 | if hasattr(parent_attr, "__wrapped__"):
32 | parent_attr = parent_attr.__wrapped__
33 |
34 | return instance_attr.__code__ != parent_attr.__code__
35 |
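A brief sketch of `is_overridden` in action (the classes are hypothetical):

from lightning_utilities.core.overrides import is_overridden

class Parent:
    def setup(self) -> None: ...

class Custom(Parent):
    def setup(self) -> None:
        print("custom setup")

class Plain(Parent): ...

assert is_overridden("setup", Custom(), parent=Parent)
assert not is_overridden("setup", Plain(), parent=Parent)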
--------------------------------------------------------------------------------
/src/lightning_utilities/core/rank_zero.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 | """Utilities that can be used for calling functions on a particular rank."""
6 |
7 | import logging
8 | import warnings
9 | from functools import wraps
10 | from typing import Any, Callable, Optional, TypeVar, Union
11 |
12 | from typing_extensions import ParamSpec, overload
13 |
14 | log = logging.getLogger(__name__)
15 |
16 | T = TypeVar("T")
17 | P = ParamSpec("P")
18 |
19 |
20 | @overload
21 | def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: ...
22 |
23 |
24 | @overload
25 | def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ...
26 |
27 |
28 | def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]:
29 | """Wrap a function to call internal function only in rank zero.
30 |
31 | Function that can be used as a decorator to enable a function/method being called only on global rank 0.
32 |
33 | """
34 |
35 | @wraps(fn)
36 | def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
37 | rank = getattr(rank_zero_only, "rank", None)
38 | if rank is None:
39 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
40 | if rank == 0:
41 | return fn(*args, **kwargs)
42 | return default
43 |
44 | return wrapped_fn
45 |
46 |
47 | def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
48 | kwargs["stacklevel"] = stacklevel
49 | log.debug(*args, **kwargs)
50 |
51 |
52 | @rank_zero_only
53 | def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
54 | """Emit debug-level messages only on global rank 0."""
55 | _debug(*args, stacklevel=stacklevel, **kwargs)
56 |
57 |
58 | def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
59 | kwargs["stacklevel"] = stacklevel
60 | log.info(*args, **kwargs)
61 |
62 |
63 | @rank_zero_only
64 | def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
65 | """Emit info-level messages only on global rank 0."""
66 | _info(*args, stacklevel=stacklevel, **kwargs)
67 |
68 |
69 | def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
70 | warnings.warn(message, stacklevel=stacklevel, **kwargs)
71 |
72 |
73 | @rank_zero_only
74 | def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
75 | """Emit warn-level messages only on global rank 0."""
76 | _warn(message, stacklevel=stacklevel, **kwargs)
77 |
78 |
79 | rank_zero_deprecation_category = DeprecationWarning
80 |
81 |
82 | def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **kwargs: Any) -> None:
83 | """Emit a deprecation warning only on global rank 0."""
84 | category = kwargs.pop("category", rank_zero_deprecation_category)
85 | rank_zero_warn(message, stacklevel=stacklevel, category=category, **kwargs)
86 |
87 |
88 | def rank_prefixed_message(message: str, rank: Optional[int]) -> str:
89 | """Add a prefix with the rank to a message."""
90 | if rank is not None:
91 | # specify the rank of the process being logged
92 | return f"[rank: {rank}] {message}"
93 | return message
94 |
95 |
96 | class WarningCache(set):
97 | """Cache for warnings."""
98 |
99 | def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
100 | """Trigger warning message."""
101 | if message not in self:
102 | self.add(message)
103 | rank_zero_warn(message, stacklevel=stacklevel, **kwargs)
104 |
105 | def deprecation(self, message: str, stacklevel: int = 6, **kwargs: Any) -> None:
106 | """Trigger deprecation message."""
107 | if message not in self:
108 | self.add(message)
109 | rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
110 |
111 | def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
112 | """Trigger info message."""
113 | if message not in self:
114 | self.add(message)
115 | rank_zero_info(message, stacklevel=stacklevel, **kwargs)
116 |
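The decorated helpers require the global rank to be set on `rank_zero_only` before first use; a minimal sketch (in a real distributed job the rank would come from the launcher, e.g. the `RANK` environment variable):

from lightning_utilities.core.rank_zero import WarningCache, rank_zero_info, rank_zero_only, rank_zero_warn

rank_zero_only.rank = 0  # assumption: a single-process run; real jobs read this from the launcher

rank_zero_info("visible only on global rank 0")
rank_zero_warn("warned only on global rank 0")

cache = WarningCache()
cache.warn("emitted once")
cache.warn("emitted once")  # deduplicated: the message is already in the cache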
--------------------------------------------------------------------------------
/src/lightning_utilities/docs/__init__.py:
--------------------------------------------------------------------------------
1 | """General tools for Docs."""
2 |
3 | from lightning_utilities.docs.formatting import adjust_linked_external_docs
4 | from lightning_utilities.docs.retriever import fetch_external_assets
5 |
6 | __all__ = ["adjust_linked_external_docs", "fetch_external_assets"]
7 |
--------------------------------------------------------------------------------
/src/lightning_utilities/docs/formatting.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 | #
4 | import glob
5 | import importlib
6 | import inspect
7 | import logging
8 | import os
9 | import re
10 | import sys
11 | from collections.abc import Iterable
12 | from typing import Optional, Union
13 |
14 |
15 | def _transform_changelog(path_in: str, path_out: str) -> None:
16 | """Adjust changelog titles so not to be duplicated.
17 |
18 | Args:
19 | path_in: input MD file
20 | path_out: output MD file
21 |
22 | """
23 | with open(path_in) as fp:
24 | chlog_lines = fp.readlines()
25 | # enrich short subsub-titles to be unique
26 | chlog_ver = ""
27 | for i, ln in enumerate(chlog_lines):
28 | if ln.startswith("## "):
29 | chlog_ver = ln[2:].split("-")[0].strip()
30 | elif ln.startswith("### "):
31 | ln = ln.replace("###", f"### {chlog_ver} -")
32 | chlog_lines[i] = ln
33 | with open(path_out, "w") as fp:
34 | fp.writelines(chlog_lines)
35 |
36 |
37 | def _linkcode_resolve(
38 | domain: str,
39 | info: dict,
40 | github_user: str,
41 | github_repo: str,
42 | main_branch: str = "master",
43 | stable_branch: str = "release/stable",
44 | ) -> str:
45 | def find_source() -> tuple[str, int, int]:
46 | # try to find the file and line number, based on code from numpy:
47 | # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
48 | obj = sys.modules[info["module"]]
49 | for part in info["fullname"].split("."):
50 | obj = getattr(obj, part)
51 | fname = str(inspect.getsourcefile(obj))
52 | # https://github.com/rtfd/readthedocs.org/issues/5735
53 | if any(s in fname for s in ("readthedocs", "rtfd", "checkouts")):
54 | # /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/
55 | # devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176
56 | path_top = os.path.abspath(os.path.join("..", "..", ".."))
57 | fname = str(os.path.relpath(fname, start=path_top))
58 | else:
59 | # Local build, imitate master
60 | fname = f"{main_branch}/{os.path.relpath(fname, start=os.path.abspath('..'))}"
61 | source, line_start = inspect.getsourcelines(obj)
62 | return fname, line_start, line_start + len(source) - 1
63 |
64 | if domain != "py" or not info["module"]:
65 | return ""
66 | try:
67 | filename = "%s#L%d-L%d" % find_source() # noqa: UP031
68 | except Exception:
69 | filename = info["module"].replace(".", "/") + ".py"
70 | # import subprocess
71 | # tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE,
72 | # universal_newlines=True).communicate()[0][:-1]
73 | branch = filename.split("/")[0]
74 | # do mapping from latest tags to master
75 | branch = {"latest": main_branch, "stable": stable_branch}.get(branch, branch)
76 | filename = "/".join([branch] + filename.split("/")[1:])
77 | return f"https://github.com/{github_user}/{github_repo}/blob/{filename}"
78 |
79 |
80 | def _load_pypi_versions(package_name: str) -> list[str]:
81 | """Load the versions of the package from PyPI.
82 |
83 | >>> _load_pypi_versions("numpy") # doctest: +ELLIPSIS
84 | ['0.9.6', '0.9.8', '1.0', ...]
85 | >>> _load_pypi_versions("scikit-learn") # doctest: +ELLIPSIS
86 | ['0.9', '0.10', '0.11', '0.12', ...]
87 |
88 | """
89 | from distutils.version import LooseVersion
90 |
91 | import requests
92 |
93 | url = f"https://pypi.org/pypi/{package_name}/json"
94 | data = requests.get(url, timeout=10).json()
95 | versions = data["releases"].keys()
96 | # keep only versions that consist solely of numbers and dots
97 | versions = {k for k in versions if re.match(r"^\d+(\.\d+)*$", k)}
98 | return sorted(versions, key=LooseVersion)
99 |
100 |
101 | def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str:
102 | """Adjust the linked external docs to be local.
103 |
104 | Args:
105 | link: the source link containing a ``{package.version}``-style placeholder to resolve
106 | pkg_ver: the package/attribute expression from the placeholder, e.g. ``numpy.__version__``
107 | version_digits: for semantic versioning, how many digits to be considered
108 |
109 | """
110 | pkg_att = pkg_ver.split(".")
111 | try:
112 | ver = _load_pypi_versions(pkg_att[0])[-1]
113 | except Exception:
114 | # load the package with all additional sub-modules
115 | module = importlib.import_module(".".join(pkg_att[:-1]))
116 | # load the version attribute from the imported module
117 | ver = getattr(module, pkg_att[-1])
118 | # drop any additional context after `+`
119 | ver = ver.split("+")[0]
120 | # crop the version to the number of digits
121 | ver = ".".join(ver.split(".")[:version_digits])
122 | # replace the version
123 | return link.replace(f"{{{pkg_ver}}}", ver)
124 |
125 |
126 | def adjust_linked_external_docs(
127 | source_link: str,
128 | target_link: str,
129 | browse_folder: Union[str, Iterable[str]],
130 | file_extensions: Iterable[str] = (".rst", ".py"),
131 | version_digits: int = 2,
132 | ) -> None:
133 | r"""Adjust the linked external docs to be local.
134 |
135 | Args:
136 | source_link: the link to be replaced
137 | target_link: the replacement link; if ``{package.version}`` is included, it will be resolved accordingly
138 | browse_folder: the location of the browsable folder
139 | file_extensions: what kind of files shall be scanned
140 | version_digits: for semantic versioning, how many digits to be considered
141 |
142 | Examples:
143 | >>> adjust_linked_external_docs(
144 | ... "https://numpy.org/doc/stable/",
145 | ... "https://numpy.org/doc/{numpy.__version__}/",
146 | ... "docs/source",
147 | ... )
148 |
149 | """
150 | list_files = []
151 | if isinstance(browse_folder, str):
152 | browse_folder = [browse_folder]
153 | for folder in browse_folder:
154 | for ext in file_extensions:
155 | list_files += glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True)
156 | if not list_files:
157 | logging.warning(f'No files were listed in folder "{browse_folder}" and pattern "{file_extensions}"')
158 | return
159 |
160 | # find the expression for package version in {} brackets if any, use re to find it
161 | pkg_ver_all = re.findall(r"{(.+)}", target_link)
162 | for pkg_ver in pkg_ver_all:
163 | target_link = _update_link_based_imported_package(target_link, pkg_ver, version_digits)
164 |
165 | # replace the source link with target link
166 | for fpath in set(list_files):
167 | with open(fpath, encoding="UTF-8") as fopen:
168 | lines = fopen.readlines()
169 | found, skip = False, False
170 | for i, ln in enumerate(lines):
171 | # prevent replacing the link inside this function's own calls
172 | if f"{adjust_linked_external_docs.__name__}(" in ln:
173 | skip = True
174 | if not skip and source_link in ln:
175 | # replace the link if any found
176 | lines[i] = ln.replace(source_link, target_link)
177 | # record the found link for later write file
178 | found = True
179 | if skip and ")" in ln:
180 | skip = False
181 | if not found:
182 | continue
183 | logging.debug(f'links adjusting in {fpath}: "{source_link}" -> "{target_link}"')
184 | with open(fpath, "w", encoding="UTF-8") as fw:
185 | fw.writelines(lines)
186 |
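`_linkcode_resolve` is meant to be bound with `functools.partial` and assigned to the `sphinx.ext.linkcode` hook; a hedged sketch of what a project's `conf.py` might contain (the user and repo names are placeholders):

# docs/source/conf.py (sketch)
from functools import partial

from lightning_utilities.docs.formatting import _linkcode_resolve

extensions = ["sphinx.ext.linkcode"]
# Sphinx calls linkcode_resolve(domain, info); the remaining arguments are pre-bound here.
linkcode_resolve = partial(_linkcode_resolve, github_user="Lightning-AI", github_repo="utilities")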
--------------------------------------------------------------------------------
/src/lightning_utilities/docs/retriever.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 | #
4 | import glob
5 | import logging
6 | import os
7 | import re
8 |
9 | import requests
10 |
11 |
12 | def _download_file(file_url: str, folder: str) -> str:
13 | """Download a file from URL to a particular folder."""
14 | fname = os.path.basename(file_url)
15 | file_path = os.path.join(folder, fname)
16 | if os.path.isfile(file_path):
17 | logging.warning(f'given file "{file_path}" already exists and will be overwritten with {file_url}')
18 | # see: https://stackoverflow.com/a/34957875
19 | rq = requests.get(file_url, timeout=10)
20 | with open(file_path, "wb") as outfile:
21 | outfile.write(rq.content)
22 | return fname
23 |
24 |
25 | def _search_all_occurrences(list_files: list[str], pattern: str) -> list[str]:
26 | """Search for all occurrences of specific pattern in a collection of files.
27 |
28 | Args:
29 | list_files: list of files to be scanned
30 | pattern: pattern for search, reg. expression
31 |
32 | """
33 | collected = []
34 | for file_path in list_files:
35 | with open(file_path, encoding="UTF-8") as fopen:
36 | body = fopen.read()
37 | found = re.findall(pattern, body)
38 | collected += found
39 | return collected
40 |
41 |
42 | def _replace_remote_with_local(file_path: str, docs_folder: str, pairs_url_path: list[tuple[str, str]]) -> None:
43 | """Replace all URL with local files in a given file.
44 |
45 | Args:
46 | file_path: file for replacement
47 | docs_folder: the location of docs relative to the project root
48 | pairs_url_path: pairs of URL and local file path to be swapped
49 |
50 | """
51 | # drop the default/global path to the docs
52 | relt_path = os.path.dirname(file_path).replace(docs_folder, "")
53 | # count the non-empty folder names to get the nesting depth
54 | depth = len([p for p in relt_path.split(os.path.sep) if p])
55 | with open(file_path, encoding="UTF-8") as fopen:
56 | body = fopen.read()
57 | for url, fpath in pairs_url_path:
58 | if depth:
59 | path_up = [".."] * depth
60 | fpath = os.path.join(*path_up, fpath)
61 | body = body.replace(url, fpath)
62 | with open(file_path, "w", encoding="UTF-8") as fw:
63 | fw.write(body)
64 |
65 |
66 | def fetch_external_assets(
67 | docs_folder: str = "docs/source",
68 | assets_folder: str = "fetched-s3-assets",
69 | file_pattern: str = "*.rst",
70 | retrieve_pattern: str = r"https?://[-a-zA-Z0-9_]+\.s3\.[-a-zA-Z0-9()_\\+.\\/=]+",
71 | ) -> None:
72 | """Search all URL in docs, download these files locally and replace online with local version.
73 |
74 | Args:
75 | docs_folder: the location of docs relative to the project root
76 | assets_folder: a folder inside ``docs_folder`` to be created for saving the online assets
77 | file_pattern: what kind of files shall be scanned
78 | retrieve_pattern: pattern for reg. expression to search URL/S3 resources
79 |
80 | """
81 | list_files = glob.glob(os.path.join(docs_folder, "**", file_pattern), recursive=True)
82 | if not list_files:
83 | logging.warning(f'no files were listed in folder "{docs_folder}" and pattern "{file_pattern}"')
84 | return
85 |
86 | urls = _search_all_occurrences(list_files, pattern=retrieve_pattern)
87 | if not urls:
88 | logging.info(f"no resources/assets were match in {docs_folder} for {retrieve_pattern}")
89 | return
90 | target_folder = os.path.join(docs_folder, assets_folder)
91 | os.makedirs(target_folder, exist_ok=True)
92 | pairs_url_file = []
93 | for i, url in enumerate(set(urls), start=1):
94 | logging.info(f" >> downloading ({i}/{len(set(urls))}): {url}")
95 | fname = _download_file(url, target_folder)
96 | pairs_url_file.append((url, os.path.join(assets_folder, fname)))
97 |
98 | for fpath in list_files:
99 | _replace_remote_with_local(fpath, docs_folder, pairs_url_file)
100 |
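Typical invocation, e.g. from a docs build script (the arguments shown are the defaults):

from lightning_utilities.docs.retriever import fetch_external_assets

# download all S3-hosted assets referenced from the RST sources and
# rewrite the remote URLs to point at the local copies
fetch_external_assets(docs_folder="docs/source", file_pattern="*.rst")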
--------------------------------------------------------------------------------
/src/lightning_utilities/install/__init__.py:
--------------------------------------------------------------------------------
1 | """Generic Installation tools."""
2 |
3 | from lightning_utilities.install.requirements import Requirement, load_requirements
4 |
5 | __all__ = ["Requirement", "load_requirements"]
6 |
--------------------------------------------------------------------------------
/src/lightning_utilities/install/requirements.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 | #
4 | import re
5 | from collections.abc import Iterable, Iterator
6 | from distutils.version import LooseVersion
7 | from pathlib import Path
8 | from typing import Any, Optional, Union
9 |
10 | from pkg_resources import Requirement, yield_lines # type: ignore[import-untyped]
11 |
12 |
13 | class _RequirementWithComment(Requirement):
14 | strict_string = "# strict"
15 |
16 | def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
17 | super().__init__(*args, **kwargs)
18 | self.comment = comment
19 | if not (pip_argument is None or pip_argument): # sanity check that it's not an empty str
20 | raise RuntimeError(f"wrong pip argument: {pip_argument}")
21 | self.pip_argument = pip_argument
22 | self.strict = self.strict_string in comment.lower()
23 |
24 | def adjust(self, unfreeze: str) -> str:
25 | """Remove version restrictions unless they are strict.
26 |
27 | >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none")
28 | 'arrow<=1.2.2,>=1.2.0'
29 | >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none")
30 | 'arrow<=1.2.2,>=1.2.0 # strict'
31 | >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all")
32 | 'arrow>=1.2.0'
33 | >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all")
34 | 'arrow<=1.2.2,>=1.2.0 # strict'
35 | >>> _RequirementWithComment("arrow").adjust("all")
36 | 'arrow'
37 | >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major")
38 | 'arrow<2.0,>=1.2.0'
39 | >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major")
40 | 'arrow<=1.2.2,>=1.2.0 # strict'
41 | >>> _RequirementWithComment("arrow>=1.2.0").adjust("major")
42 | 'arrow>=1.2.0'
43 | >>> _RequirementWithComment("arrow").adjust("major")
44 | 'arrow'
45 |
46 | """
47 | out = str(self)
48 | if self.strict:
49 | return f"{out} {self.strict_string}"
50 | if unfreeze == "major":
51 | for operator, version in self.specs:
52 | if operator in ("<", "<="):
53 | major = LooseVersion(version).version[0]
54 | # replace upper bound with major version increased by one
55 | return out.replace(f"{operator}{version}", f"<{int(major) + 1}.0")
56 | elif unfreeze == "all":
57 | for operator, version in self.specs:
58 | if operator in ("<", "<="):
59 | # drop upper bound
60 | return out.replace(f"{operator}{version},", "")
61 | elif unfreeze != "none":
62 | raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.")
63 | return out
64 |
65 |
66 | def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
67 | r"""Adapted from `pkg_resources.parse_requirements` to include comments.
68 |
69 | >>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt']
70 | >>> [r.adjust('none') for r in _parse_requirements(txt)]
71 | ['this', 'example', 'foo # strict', 'thing']
72 | >>> txt = '\\n'.join(txt)
73 | >>> [r.adjust('none') for r in _parse_requirements(txt)]
74 | ['this', 'example', 'foo # strict', 'thing']
75 |
76 | """
77 | lines = yield_lines(strs)
78 | pip_argument = None
79 | for line in lines:
80 | # Drop comments -- a hash without a space may be in a URL.
81 | if " #" in line:
82 | comment_pos = line.find(" #")
83 | line, comment = line[:comment_pos], line[comment_pos:]
84 | else:
85 | comment = ""
86 | # If there is a line continuation, drop it, and append the next line.
87 | if line.endswith("\\"):
88 | line = line[:-2].strip()
89 | try:
90 | line += next(lines)
91 | except StopIteration:
92 | return
93 | # If there's a pip argument, save it
94 | if line.startswith("--"):
95 | pip_argument = line
96 | continue
97 | if line.startswith("-r "):
98 | # linked requirement files are unsupported
99 | continue
100 | if "@" in line or re.search("https?://", line):
101 | # skip lines with links like `pesq @ git+https://github.com/ludlows/python-pesq`
102 | continue
103 | yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument)
104 | pip_argument = None
105 |
106 |
107 | def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]:
108 | """Load requirements from a file.
109 |
110 | >>> import os
111 | >>> from lightning_utilities import _PROJECT_ROOT
112 | >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
113 | >>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
114 | ['sphinx<6.0,>=4.0', ...]
115 |
116 | """
117 | if unfreeze not in {"none", "major", "all"}:
118 | raise ValueError(f'unsupported option of "{unfreeze}"')
119 | path = Path(path_dir) / file_name
120 | if not path.exists():
121 | raise FileNotFoundError(f"missing file for {(path_dir, file_name, path)}")
122 | text = path.read_text()
123 | return [req.adjust(unfreeze) for req in _parse_requirements(text)]
124 |
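The `unfreeze` modes differ only in how upper bounds are treated; a usage sketch (run from the project root, since the path is resolved relative to the working directory):

from lightning_utilities.install.requirements import load_requirements

print(load_requirements("requirements", "docs.txt", unfreeze="none"))   # keep bounds as written
print(load_requirements("requirements", "docs.txt", unfreeze="major"))  # cap at the next major version
print(load_requirements("requirements", "docs.txt", unfreeze="all"))    # drop upper bounds entirely
# requirements marked "# strict" are left untouched in every mode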
--------------------------------------------------------------------------------
/src/lightning_utilities/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/utilities/2658ff8b341fc0f30e9d9e9042a9d6c0e4de49a2/src/lightning_utilities/py.typed
--------------------------------------------------------------------------------
/src/lightning_utilities/test/__init__.py:
--------------------------------------------------------------------------------
1 | """General testing functionality."""
2 |
--------------------------------------------------------------------------------
/src/lightning_utilities/test/warning.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 | #
4 | import re
5 | import warnings
6 | from collections.abc import Generator
7 | from contextlib import contextmanager
8 | from typing import Optional
9 |
10 |
11 | @contextmanager
12 | def no_warning_call(expected_warning: type[Warning] = Warning, match: Optional[str] = None) -> Generator:
13 | """Check that no warning was raised/emitted under this context manager."""
14 | with warnings.catch_warnings(record=True) as record:
15 | yield
16 |
17 | for w in record:
18 | if issubclass(w.category, expected_warning) and (match is None or re.compile(match).search(str(w.message))):
19 | raise AssertionError(f"`{expected_warning.__name__}` was raised: {w.message!r}")
20 |
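A sketch of how `no_warning_call` reads in a test (the test function is hypothetical):

import warnings

import pytest

from lightning_utilities.test.warning import no_warning_call

def test_no_deprecation_emitted():  # hypothetical test for illustration
    # passes: the emitted warning has a different category
    with no_warning_call(UserWarning, match="deprecated"):
        warnings.warn("unrelated", RuntimeWarning)

    # raises AssertionError: a matching UserWarning was emitted
    with pytest.raises(AssertionError, match="UserWarning"):
        with no_warning_call(UserWarning, match="deprecated"):
            warnings.warn("this API is deprecated", UserWarning)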
--------------------------------------------------------------------------------
/tests/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | _PATH_HERE = os.path.dirname(__file__)
4 | _PATH_ROOT = os.path.dirname(os.path.dirname(_PATH_HERE))
5 | _PATH_SCRIPTS = os.path.join(_PATH_ROOT, "scripts")
6 |
--------------------------------------------------------------------------------
/tests/scripts/test_adjust_torch_versions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import subprocess
4 | import sys
5 |
6 | from scripts import _PATH_SCRIPTS
7 |
8 | REQUIREMENTS_SAMPLE = """
9 | # This is sample requirements file
10 | # with multi line comments
11 |
12 | torchvision >=0.13.0, <0.16.0 # sample # comment
13 | gym[classic,control] >=0.17.0, <0.27.0
14 | ipython[all] <8.15.0 # strict
15 | torchmetrics >=0.10.0, <1.3.0
16 | deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
17 |
18 | """
19 | REQUIREMENTS_EXPECTED = """
20 | # This is sample requirements file
21 | # with multi line comments
22 |
23 | torchvision==0.11.1 # sample # comment
24 | gym[classic,control] >=0.17.0, <0.27.0
25 | ipython[all] <8.15.0 # strict
26 | torchmetrics >=0.10.0, <1.3.0
27 | deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
28 |
29 | """
30 |
31 |
32 | def test_adjust_torch_versions_call(tmp_path) -> None:
33 | path_script = os.path.join(_PATH_SCRIPTS, "adjust-torch-versions.py")
34 | path_req_file = str(tmp_path / "requirements.txt")
35 | with open(path_req_file, "w", encoding="utf8") as fopen:
36 | fopen.write(REQUIREMENTS_SAMPLE)
37 |
38 | return_code = subprocess.call([sys.executable, path_script, path_req_file, "1.10.0"]) # noqa: S603
39 | assert return_code == 0
40 |
41 | with open(path_req_file, encoding="utf8") as fopen:
42 | req_result = fopen.read()
43 | # TODO: no idea why parsing lines on Windows leaves an extra line after each line
44 | # tried strip, regex, hard-coded replace but none worked... so adjusting tests
45 | if platform.system() == "Windows":
46 | req_result = req_result.replace("\n\n", "\n")
47 | assert req_result == REQUIREMENTS_EXPECTED
48 |
--------------------------------------------------------------------------------
/tests/unittests/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | _PATH_UNITTESTS = os.path.dirname(__file__)
4 | _PATH_ROOT = os.path.dirname(os.path.dirname(_PATH_UNITTESTS))
5 |
--------------------------------------------------------------------------------
/tests/unittests/cli/test_dependencies.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from lightning_utilities.cli.dependencies import (
4 | prune_packages_in_requirements,
5 | replace_oldest_version,
6 | replace_package_in_requirements,
7 | )
8 |
9 | _PATH_ROOT = Path(__file__).parent.parent.parent
10 |
11 |
12 | def test_prune_packages(tmpdir):
13 | req_file = tmpdir / "requirements.txt"
14 | with open(req_file, "w") as fp:
15 | fp.writelines(["fire\n", "abc>=0.1\n"])
16 | prune_packages_in_requirements("abc", req_files=[str(req_file)])
17 | with open(req_file) as fp:
18 | lines = fp.readlines()
19 | assert lines == ["fire\n"]
20 |
21 |
22 | def test_oldest_packages(tmpdir):
23 | req_file = tmpdir / "requirements.txt"
24 | with open(req_file, "w") as fp:
25 | fp.writelines(["fire>0.2\n", "abc>=0.1\n"])
26 | replace_oldest_version(req_files=[str(req_file)])
27 | with open(req_file) as fp:
28 | lines = fp.readlines()
29 | assert lines == ["fire>0.2\n", "abc==0.1\n"]
30 |
31 |
32 | def test_replace_packages(tmpdir):
33 | req_file = tmpdir / "requirements.txt"
34 | with open(req_file, "w") as fp:
35 | fp.writelines(["torchvision>=0.2\n", "torch>=1.0 # comment\n", "torchtext <0.3\n"])
36 | replace_package_in_requirements(old_package="torch", new_package="pytorch", req_files=[str(req_file)])
37 | with open(req_file) as fp:
38 | lines = fp.readlines()
39 | assert lines == ["torchvision>=0.2\n", "pytorch>=1.0 # comment\n", "torchtext <0.3\n"]
40 |
--------------------------------------------------------------------------------
/tests/unittests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 |
3 | import os
4 | import shutil
5 | import tempfile
6 | from pathlib import Path
7 |
8 | import pytest
9 |
10 | from unittests import _PATH_ROOT
11 |
12 | _PATH_DOCS = os.path.join(_PATH_ROOT, "docs", "source")
13 |
14 |
15 | @pytest.fixture(scope="session")
16 | def temp_docs():
17 | """Create a dummy documentation folder."""
18 | # create a folder for docs
19 | docs_folder = Path(tempfile.mkdtemp())
20 | # copy all real docs from _PATH_DOCS to local temp_docs
21 | for root, _, files in os.walk(_PATH_DOCS):
22 | for file in files:
23 | fpath = os.path.join(root, file)
24 | temp_path = docs_folder / os.path.relpath(fpath, _PATH_DOCS)
25 | temp_path.parent.mkdir(exist_ok=True, parents=True)
26 | with open(fpath, "rb") as fopen:
27 | temp_path.write_bytes(fopen.read())
28 | yield str(docs_folder)
29 | # remove the folder
30 | shutil.rmtree(docs_folder, ignore_errors=True)
31 |
--------------------------------------------------------------------------------
/tests/unittests/core/test_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from lightning_utilities.core.enums import StrEnum
4 |
5 |
6 | def test_consistency():
7 | class MyEnum(StrEnum):
8 | FOO = "FOO"
9 | BAR = "BAR"
10 | BAZ = "BAZ"
11 | NUM = "32"
12 |
13 | # normal equality, case invariant
14 | assert MyEnum.FOO == "FOO"
15 | assert MyEnum.FOO == "foo"
16 |
17 | # int support
18 | assert MyEnum.NUM == 32
19 | assert MyEnum.NUM in (32, "32")
20 |
21 | # key-based
22 | assert MyEnum.from_str("num") == MyEnum.NUM
23 |
24 | # collections
25 | assert MyEnum.BAZ not in ("FOO", "BAR")
26 | assert MyEnum.BAZ in ("FOO", "BAZ")
27 | assert MyEnum.BAZ in ("baz", "FOO")
28 | assert MyEnum.BAZ not in {"BAR", "FOO"}
29 | # hash cannot be case invariant
30 | assert MyEnum.BAZ not in {"BAZ", "FOO"}
31 | assert MyEnum.BAZ in {"baz", "FOO"}
32 |
33 |
34 | def test_comparison_with_other_enum():
35 | class MyEnum(StrEnum):
36 | FOO = "FOO"
37 |
38 | class OtherEnum(Enum):
39 | FOO = 123
40 |
41 | assert not MyEnum.FOO.__eq__(OtherEnum.FOO)
42 |
43 |
44 | def test_create_from_string():
45 | class MyEnum(StrEnum):
46 | t1 = "T/1"
47 | T2 = "t:2"
48 |
49 | assert MyEnum.from_str("T1", source="key")
50 | assert MyEnum.try_from_str("T1", source="value") is None
51 | assert MyEnum.from_str("T1", source="any")
52 |
53 | assert MyEnum.try_from_str("T:2", source="key") is None
54 | assert MyEnum.from_str("T:2", source="value")
55 | assert MyEnum.from_str("T:2", source="any")
56 |
--------------------------------------------------------------------------------
/tests/unittests/core/test_imports.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import re
3 | from unittest import mock
4 | from unittest.mock import Mock
5 |
6 | import pytest
7 |
8 | from lightning_utilities.core.imports import (
9 | RequirementCache,
10 | compare_version,
11 | get_dependency_min_version_spec,
12 | lazy_import,
13 | module_available,
14 | requires,
15 | )
16 |
17 | try:
18 | from importlib.metadata import PackageNotFoundError
19 | except ImportError:
20 | # Python < 3.8
21 | from importlib_metadata import PackageNotFoundError
22 |
23 |
24 | def test_module_exists():
25 | assert module_available("_pytest")
26 | assert module_available("_pytest.mark.structures")
27 | assert not module_available("_pytest.mark.asdf")
28 | assert not module_available("asdf")
29 | assert not module_available("asdf.bla.asdf")
30 |
31 |
32 | def test_compare_version(monkeypatch):
33 | monkeypatch.setattr(pytest, "__version__", "1.8.9")
34 | assert not compare_version("pytest", operator.ge, "1.10.0")
35 | assert compare_version("pytest", operator.lt, "1.10.0")
36 |
37 | monkeypatch.setattr(pytest, "__version__", "1.10.0.dev123")
38 | assert compare_version("pytest", operator.ge, "1.10.0.dev123")
39 | assert not compare_version("pytest", operator.ge, "1.10.0.dev124")
40 |
41 | assert compare_version("pytest", operator.ge, "1.10.0.dev123", use_base_version=True)
42 | assert compare_version("pytest", operator.ge, "1.10.0.dev124", use_base_version=True)
43 |
44 | monkeypatch.setattr(pytest, "__version__", "1.10.0a0+0aef44c") # dev version before rc
45 | assert compare_version("pytest", operator.ge, "1.10.0.rc0", use_base_version=True)
46 | assert not compare_version("pytest", operator.ge, "1.10.0.rc0")
47 | assert compare_version("pytest", operator.ge, "1.10.0", use_base_version=True)
48 | assert not compare_version("pytest", operator.ge, "1.10.0")
49 |
50 |
51 | def test_requirement_cache():
52 | assert RequirementCache(f"pytest>={pytest.__version__}")
53 | assert not RequirementCache(f"pytest<{pytest.__version__}")
54 | assert "pip install -U 'not-found-requirement'" in str(RequirementCache("not-found-requirement"))
55 |
56 | # invalid requirement is skipped by valid module
57 | assert RequirementCache(f"pytest<{pytest.__version__}", "pytest")
58 |
59 | cache = RequirementCache("this_module_is_not_installed")
60 | assert not cache
61 | assert "pip install -U 'this_module_is_not_installed" in str(cache)
62 |
63 | cache = RequirementCache("this_module_is_not_installed", "this_also_is_not")
64 | assert not cache
65 | assert "pip install -U 'this_module_is_not_installed" in str(cache)
66 |
67 | cache = RequirementCache("pytest[not-valid-extra]")
68 | assert not cache
69 | assert "pip install -U 'pytest[not-valid-extra]" in str(cache)
70 |
71 |
72 | @mock.patch("lightning_utilities.core.imports.Requirement")
73 | @mock.patch("lightning_utilities.core.imports._version")
74 | @mock.patch("lightning_utilities.core.imports.distribution")
75 | def test_requirement_cache_with_extras(distribution_mock, version_mock, requirement_mock):
76 | requirement_mock().specifier.contains.return_value = True
77 | requirement_mock().name = "jsonargparse"
78 | requirement_mock().extras = []
79 | version_mock.return_value = "1.0.0"
80 | assert RequirementCache("jsonargparse>=1.0.0")
81 |
82 | with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock:
83 | get_extra_req_mock.return_value = [
84 | # Extra packages, all versions satisfied
85 | Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))),
86 | Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=True))),
87 | ]
88 | distribution_mock.return_value = Mock(version="0.10.0")
89 | requirement_mock().extras = ["signatures"]
90 | assert RequirementCache("jsonargparse[signatures]>=1.0.0")
91 |
92 | with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock:
93 | get_extra_req_mock.return_value = [
94 | # Extra packages, but not all versions are satisfied
95 | Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))),
96 | Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=False))),
97 | ]
98 | distribution_mock.return_value = Mock(version="0.10.0")
99 | requirement_mock().extras = ["signatures"]
100 | assert not RequirementCache("jsonargparse[signatures]>=1.0.0")
101 |
102 |
103 | @mock.patch("lightning_utilities.core.imports._version")
104 | def test_requirement_cache_with_prerelease_package(version_mock):
105 | version_mock.return_value = "0.11.0"
106 | assert RequirementCache("transformer-engine>=0.11.0")
107 | version_mock.return_value = "0.11.0.dev0+931b44f"
108 | assert not RequirementCache("transformer-engine>=0.11.0")
109 | version_mock.return_value = "1.10.0.dev0+931b44f"
110 | assert RequirementCache("transformer-engine>=0.11.0")
111 |
112 |
113 | def test_module_available_cache():
114 | assert RequirementCache(module="pytest")
115 | assert not RequirementCache(module="this_module_is_not_installed")
116 | assert "pip install -U this_module_is_not_installed" in str(RequirementCache(module="this_module_is_not_installed"))
117 |
118 |
119 | def test_get_dependency_min_version_spec():
120 | attrs_min_version_spec = get_dependency_min_version_spec("pytest", "attrs")
121 | assert re.match(r"^>=[\d.]+$", attrs_min_version_spec)
122 |
123 | with pytest.raises(ValueError, match="'invalid' not found in package 'pytest'"):
124 | get_dependency_min_version_spec("pytest", "invalid")
125 |
126 | with pytest.raises(PackageNotFoundError, match="invalid"):
127 | get_dependency_min_version_spec("invalid", "invalid")
128 |
129 |
130 | def test_lazy_import():
131 | def callback_fcn():
132 | raise ValueError
133 |
134 | math = lazy_import("math", callback=callback_fcn)
135 | with pytest.raises(ValueError, match=""): # noqa: PT011
136 | math.floor(5.1)
137 |
138 | module = lazy_import("asdf")
139 | with pytest.raises(ModuleNotFoundError, match="No module named 'asdf'"):
140 | print(module)
141 |
142 | os = lazy_import("os")
143 | assert os.getcwd()
144 |
145 |
146 | @requires("torch.unknown.subpackage")
147 | def my_torch_func(i: int) -> int:
148 | import torch # noqa
149 |
150 | return i
151 |
152 |
153 | def test_torch_func_raised():
154 | with pytest.raises(
155 | ModuleNotFoundError,
156 | match="Required dependencies not available: \nModule not found: 'torch.unknown.subpackage'. ",
157 | ):
158 | my_torch_func(42)
159 |
160 |
161 | @requires("random")
162 | def my_random_func(nb: int) -> int:
163 | from random import randint
164 |
165 | return randint(0, nb)
166 |
167 |
168 | def test_rand_func_passed():
169 | assert 0 <= my_random_func(42) <= 42
170 |
171 |
172 | class MyTorchClass:
173 | @requires("torch>99.0", "random")
174 | def __init__(self):
175 | from random import randint
176 |
177 | import torch # noqa
178 |
179 | self._rnd = randint(1, 9)
180 |
181 |
182 | def test_torch_class_raised():
183 | with pytest.raises(
184 | ModuleNotFoundError, match="Required dependencies not available: \nModule not found: 'torch>99.0'."
185 | ):
186 | MyTorchClass()
187 |
188 |
189 | class MyRandClass:
190 | @requires("random")
191 | def __init__(self, nb: int):
192 | from random import randint
193 |
194 | self._rnd = randint(1, nb)
195 |
196 |
197 | def test_rand_class_passed():
198 | cls = MyRandClass(42)
199 | assert 0 <= cls._rnd <= 42
200 |
--------------------------------------------------------------------------------
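A minimal sketch of how the utilities exercised above are typically combined in application code. `RequirementCache` and `requires` are the real helpers from `lightning_utilities.core.imports`; the `psutil` dependency and both function names are purely illustrative:

```python
from lightning_utilities.core.imports import RequirementCache, requires

# evaluate once at import time; the boolean result is cached on the instance
_PSUTIL_AVAILABLE = RequirementCache("psutil>=5.0")


def cpu_percent() -> float:
    if not _PSUTIL_AVAILABLE:
        # str() of a failed cache includes a "pip install -U ..." hint
        raise ModuleNotFoundError(str(_PSUTIL_AVAILABLE))
    import psutil

    return psutil.cpu_percent()


@requires("psutil")  # raises ModuleNotFoundError at call time when missing
def memory_percent() -> float:
    import psutil

    return psutil.virtual_memory().percent
```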
/tests/unittests/core/test_inheritance.py:
--------------------------------------------------------------------------------
1 | from lightning_utilities.core.inheritance import get_all_subclasses
2 |
3 |
4 | def test_get_all_subclasses():
5 | class A1: ...
6 |
7 | class A2(A1): ...
8 |
9 | class B1: ...
10 |
11 | class B2(B1): ...
12 |
13 | class C(A2, B2): ...
14 |
15 | assert get_all_subclasses(A1) == {A2, C}
16 | assert get_all_subclasses(A2) == {C}
17 | assert get_all_subclasses(B1) == {B2, C}
18 | assert get_all_subclasses(B2) == {C}
19 | assert get_all_subclasses(C) == set()
20 |
--------------------------------------------------------------------------------
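As the test above shows, `get_all_subclasses` walks the subclass tree transitively. A common use is an implicit plugin registry; the `Callback` hierarchy below is illustrative:

```python
from lightning_utilities.core.inheritance import get_all_subclasses


class Callback: ...


class Timer(Callback): ...


class Checkpoint(Callback): ...


# grandchildren are included too; only classes already defined/imported
# at the time of the call can be discovered this way
registry = {cls.__name__: cls for cls in get_all_subclasses(Callback)}
assert set(registry) == {"Timer", "Checkpoint"}
```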
/tests/unittests/core/test_overrides.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from functools import partial, wraps
3 | from typing import Any, Callable
4 | from unittest.mock import Mock
5 |
6 | import pytest
7 |
8 | from lightning_utilities.core.overrides import is_overridden
9 |
10 |
11 | class LightningModule:
12 | def training_step(self): ...
13 |
14 |
15 | class BoringModel(LightningModule):
16 | def training_step(self): ...
17 |
18 |
19 | class Strategy:
20 | @contextmanager
21 | def model_sharded_context(self): ...
22 |
23 |
24 | class SingleDeviceStrategy(Strategy): ...
25 |
26 |
27 | def test_is_overridden():
28 | assert not is_overridden("whatever", object(), parent=LightningModule)
29 |
30 | class TestModel(BoringModel):
31 | def foo(self):
32 | pass
33 |
34 | def bar(self):
35 | return 1
36 |
37 | with pytest.raises(ValueError, match="The parent should define the method"):
38 | is_overridden("foo", TestModel(), parent=BoringModel)
39 |
40 | # normal usage
41 | assert is_overridden("training_step", BoringModel(), parent=LightningModule)
42 |
43 | # reversed. works even without inheritance
44 | assert is_overridden("training_step", LightningModule(), parent=BoringModel)
45 |
46 | class WrappedModel(TestModel):
47 | def __new__(cls, *args: Any, **kwargs: Any):
48 | obj = super().__new__(cls)
49 | obj.foo = cls.wrap(obj.foo)
50 | obj.bar = cls.wrap(obj.bar)
51 | return obj
52 |
53 | @staticmethod
54 | def wrap(fn) -> Callable:
55 | @wraps(fn)
56 | def wrapper():
57 | return fn()
58 |
59 | return wrapper
60 |
61 | def bar(self):
62 | return 2
63 |
64 | # `functools.wraps()` support
65 | assert not is_overridden("foo", WrappedModel(), parent=TestModel)
66 | assert is_overridden("bar", WrappedModel(), parent=TestModel)
67 |
68 | # `Mock` support
69 | mock = Mock(spec=BoringModel, wraps=BoringModel())
70 | assert is_overridden("training_step", mock, parent=LightningModule)
71 |
72 | # `partial` support
73 | model = BoringModel()
74 | model.training_step = partial(model.training_step)
75 | assert is_overridden("training_step", model, parent=LightningModule)
76 |
77 | # `@contextmanager` support
78 | assert not is_overridden("model_sharded_context", SingleDeviceStrategy(), Strategy)
79 |
--------------------------------------------------------------------------------
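The signature exercised above is `is_overridden(method_name, instance, parent)`. A sketch of the typical framework-side pattern, guarding a user hook (all class and method names here are illustrative):

```python
from lightning_utilities.core.overrides import is_overridden


class Hook:
    def on_fit_start(self): ...


class UserHook(Hook):
    def on_fit_start(self):
        print("custom setup")


def run(instance: Hook) -> None:
    # skip the dispatch entirely when the user kept the default no-op hook
    if is_overridden("on_fit_start", instance, parent=Hook):
        instance.on_fit_start()


run(UserHook())  # prints "custom setup"
run(Hook())  # nothing happens
```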
/tests/unittests/core/test_rank_zero.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
4 |
5 |
6 | def test_rank_zero_only_raises():
7 | foo = rank_zero_only(lambda x: x + 1)
8 | with pytest.raises(RuntimeError, match="rank_zero_only.rank` needs to be set "):
9 | foo(1)
10 |
11 |
12 | @pytest.mark.parametrize("rank", [0, 1, 4])
13 | def test_rank_prefixed_message(rank):
14 | rank_zero_only.rank = rank
15 | message = rank_prefixed_message("bar", rank)
16 | assert message == f"[rank: {rank}] bar"
17 | # reset
18 | del rank_zero_only.rank
19 |
20 |
21 | def test_rank_zero_only_default():
22 | foo = lambda: "foo"
23 | rank_zero_foo = rank_zero_only(foo, "not foo")
24 |
25 | rank_zero_only.rank = 0
26 | assert rank_zero_foo() == "foo"
27 |
28 | rank_zero_only.rank = 1
29 | assert rank_zero_foo() == "not foo"
30 |
--------------------------------------------------------------------------------
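Putting the pieces from the tests together: the rank must be assigned to `rank_zero_only.rank` before the first decorated call, and non-zero ranks receive the optional default return value. A sketch, assuming a launcher that exports a `LOCAL_RANK` environment variable (that variable name is an assumption of this example):

```python
import os

from lightning_utilities.core.rank_zero import rank_zero_only

# set once, before any decorated function runs; raises RuntimeError otherwise
rank_zero_only.rank = int(os.environ.get("LOCAL_RANK", "0"))


@rank_zero_only
def log_hyperparams(params: dict) -> None:
    print(f"hparams: {params}")


log_hyperparams({"lr": 1e-3})  # runs on rank 0 only; returns None on other ranks
```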
/tests/unittests/docs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/utilities/2658ff8b341fc0f30e9d9e9042a9d6c0e4de49a2/tests/unittests/docs/__init__.py
--------------------------------------------------------------------------------
/tests/unittests/docs/test_formatting.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import re
3 |
4 | import pytest
5 |
6 | from lightning_utilities.docs import adjust_linked_external_docs
7 |
8 |
9 | @pytest.mark.online
10 | def test_adjust_linked_external_docs(temp_docs):
11 | # take config as it includes API references with `stable`
12 | path_conf = os.path.join(temp_docs, "conf.py")
13 | path_testpage = os.path.join(temp_docs, "test-page.rst")
14 |
15 | def _get_line_with_pattern(path_rst: str, pattern: str) -> str:
16 | with open(path_rst, encoding="UTF-8") as fopen:
17 | lines = fopen.readlines()
18 | # find the first line containing the given pattern
19 | return next(ln for ln in lines if pattern in ln)
20 |
21 | # validate the initial expectations
22 | line = _get_line_with_pattern(path_conf, pattern='"numpy":')
23 | assert "https://numpy.org/doc/stable/" in line
24 | line = _get_line_with_pattern(path_testpage, pattern="Link to scikit-learn stable documentation:")
25 | assert "https://scikit-learn.org/stable/" in line
26 |
27 | adjust_linked_external_docs(
28 | "https://numpy.org/doc/stable/", "https://numpy.org/doc/{numpy.__version__}/", temp_docs
29 | )
30 | adjust_linked_external_docs(
31 | "https://scikit-learn.org/stable/", "https://scikit-learn.org/{scikit-learn}/", temp_docs
32 | )
33 |
34 | # validate the final state of index page
35 | line = _get_line_with_pattern(path_conf, pattern='"numpy":')
36 | assert re.search(r"https://numpy.org/doc/([1-9]\d*)\.(\d+)/", line)
37 | line = _get_line_with_pattern(path_testpage, pattern="Link to scikit-learn stable documentation:")
38 | assert re.search(r"https://scikit-learn.org/([1-9]\d*)\.(\d+)/", line)
39 |
--------------------------------------------------------------------------------
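The call pattern tested above is `adjust_linked_external_docs(source_link, target_link, docs_folder)`, where a `{package.__version__}` placeholder in the target is resolved from the installed package. A sketch of the typical `conf.py` usage; the `docs/source` path is illustrative:

```python
from lightning_utilities.docs import adjust_linked_external_docs

# rewrite pinned ".../stable/" links so API references match the version
# the docs were actually built against
adjust_linked_external_docs(
    "https://numpy.org/doc/stable/",
    "https://numpy.org/doc/{numpy.__version__}/",
    "docs/source",  # illustrative docs root
)
```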
/tests/unittests/docs/test_retriever.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import shutil
3 |
4 | import pytest
5 |
6 | from lightning_utilities.docs import fetch_external_assets
7 |
8 |
9 | @pytest.mark.online
10 | def test_retriever_s3(temp_docs):
11 | # take the index page
12 | path_index = os.path.join(temp_docs, "index.rst")
13 | # copy it to another location to test depth
14 | path_page = os.path.join(temp_docs, "any", "extra", "page.rst")
15 | os.makedirs(os.path.dirname(path_page), exist_ok=True)
16 | shutil.copy(path_index, path_page)
17 |
18 | def _get_line_with_figure(path_rst: str) -> str:
19 | with open(path_rst, encoding="UTF-8") as fopen:
20 | lines = fopen.readlines()
21 | # find the first line with figure reference
22 | return next(ln for ln in lines if ln.startswith(".. figure::"))
23 |
24 | # validate the initial expectations
25 | line = _get_line_with_figure(path_index)
26 | # that the image exists
27 | assert "Lightning.gif" in line
28 | # and it is sourced in S3
29 | assert ".s3." in line
30 |
31 | fetch_external_assets(docs_folder=temp_docs)
32 |
33 | # validate the final state of index page
34 | line = _get_line_with_figure(path_index)
35 | # that the image exists
36 | assert os.path.join("fetched-s3-assets", "Lightning.gif") in line
37 | # but it is not sourced from S3
38 | assert ".s3." not in line
39 |
40 | # validate the final state of additional page
41 | line = _get_line_with_figure(path_page)
42 | # that the image exists in the proper depth
43 | assert os.path.join("..", "..", "fetched-s3-assets", "Lightning.gif") in line
44 | # but it is not sourced from S3
45 | assert ".s3." not in line
46 |
--------------------------------------------------------------------------------
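A one-call sketch of the behaviour verified above, again with an illustrative docs root:

```python
from lightning_utilities.docs import fetch_external_assets

# downloads S3-hosted images referenced by `.. figure::` directives into a
# local `fetched-s3-assets` folder and rewrites each path with the correct
# relative depth for nested pages
fetch_external_assets(docs_folder="docs/source")
```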
/tests/unittests/mocks.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 | from typing import Any
3 |
4 | from lightning_utilities.core.imports import package_available
5 |
6 | if package_available("torch"):
7 | import torch
8 | else:
9 | # minimal torch implementation to avoid installing torch in testing CI
10 | class TensorMock:
11 | def __init__(self, data) -> None:
12 | self.data = data
13 |
14 | def __add__(self, other):
15 | """Perform and operation."""
16 | if isinstance(self.data, Iterable):
17 | if isinstance(other, (int, float)):
18 | return TensorMock([a + other for a in self.data])
19 | if isinstance(other, Iterable):
20 | return TensorMock([a + b for a, b in zip(self, other)])
21 | return self.data + other
22 |
23 | def __mul__(self, other):
24 | """Perform mul operation."""
25 | if isinstance(self.data, Iterable):
26 | if isinstance(other, (int, float)):
27 | return TensorMock([a * other for a in self.data])
28 | if isinstance(other, Iterable):
29 | return TensorMock([a * b for a, b in zip(self, other)])
30 | return self.data * other
31 |
32 | def __iter__(self):
33 | """Iterate."""
34 | return iter(self.data)
35 |
36 | def __repr__(self) -> str:
37 | """Return object representation."""
38 | return repr(self.data)
39 |
40 | def __eq__(self, other):
41 | """Perform equal operation."""
42 | return self.data == other
43 |
44 | def add_(self, value):
45 | self.data += value
46 | return self.data
47 |
48 | class TorchMock:
49 | Tensor = TensorMock
50 |
51 | @staticmethod
52 | def tensor(data: Any) -> TensorMock:
53 | return TensorMock(data)
54 |
55 | @staticmethod
56 | def equal(a: Any, b: Any) -> bool:
57 | return a == b
58 |
59 | @staticmethod
60 | def arange(*args: Any) -> TensorMock:
61 | return TensorMock(list(range(*args)))
62 |
63 | torch = TorchMock()
64 |
--------------------------------------------------------------------------------
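A quick sanity check of what the mock supports when torch is absent; the `unittests.mocks` import path assumes the tests are run from the tests root:

```python
from unittests.mocks import torch

t = torch.tensor([1, 2, 3])
assert torch.equal(t + 1, [2, 3, 4])  # scalar broadcast
assert torch.equal(t * t, [1, 4, 9])  # elementwise against another iterable
assert torch.equal(torch.arange(3), [0, 1, 2])
```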
/tests/unittests/test/test_warnings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from re import escape
3 |
4 | import pytest
5 |
6 | from lightning_utilities.test.warning import no_warning_call
7 |
8 |
9 | def test_no_warning_call():
10 | with no_warning_call():
11 | ...
12 |
13 | with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")), no_warning_call():
14 | warnings.warn("foo")
15 |
16 | with no_warning_call(DeprecationWarning):
17 | warnings.warn("foo")
18 |
19 | class MyDeprecationWarning(DeprecationWarning): ...
20 |
21 | with (
22 | pytest.raises(AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")),
23 | no_warning_call(DeprecationWarning),
24 | ):
25 | warnings.warn("bar", category=MyDeprecationWarning)
26 |
--------------------------------------------------------------------------------
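In user test suites, `no_warning_call` inverts the usual `pytest.warns` check: the block fails if a matching warning *is* emitted. A minimal sketch; `normalize` and the test name are illustrative:

```python
import warnings

from lightning_utilities.test.warning import no_warning_call


def normalize(values: list) -> list:
    if sum(values) == 0:
        warnings.warn("zero sum, returning input unchanged")
        return values
    return [v / sum(values) for v in values]


def test_normalize_is_quiet_for_valid_input():
    # raises AssertionError if any UserWarning is emitted inside the block
    with no_warning_call(UserWarning):
        normalize([1, 2, 3])
```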