├── .codecov.yml
├── .github
├── CODEOWNERS
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ ├── documentation.md
│ └── feature_request.md
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
├── stale.yml
└── workflows
│ ├── ci-checks.yml
│ ├── ci-cloud.yml
│ ├── ci-testing.yml
│ ├── cleanup-caches.yml
│ ├── docs-build.yml
│ ├── greetings.yml
│ ├── label-conflicts.yml
│ └── release-pypi.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── CHANGELOG.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── _requirements
├── _ci.txt
├── _docs.txt
├── extra.txt
├── test.txt
└── typing.txt
├── docs
├── .build_docs.sh
├── Makefile
├── make.bat
└── source
│ ├── _static
│ ├── copybutton.js
│ └── images
│ │ ├── icon.svg
│ │ ├── logo-large.svg
│ │ ├── logo-small.svg
│ │ ├── logo.png
│ │ └── logo.svg
│ ├── _templates
│ └── theme_variables.jinja
│ ├── conf.py
│ └── index.rst
├── examples
├── demo-tensorflow-keras.py
├── demo-upload-download.py
├── resume-lightning-training.py
├── train-model-and-simple-save.py
├── train-model-with-lightning-callback.py
└── train-model-with-lightning-logger.py
├── pyproject.toml
├── requirements.txt
├── setup.py
├── src
└── litmodels
│ ├── __about__.py
│ ├── __init__.py
│ ├── demos
│ └── __init__.py
│ ├── integrations
│ ├── __init__.py
│ ├── checkpoints.py
│ ├── duplicate.py
│ ├── imports.py
│ └── mixins.py
│ └── io
│ ├── __init__.py
│ ├── cloud.py
│ ├── gateway.py
│ └── utils.py
└── tests
├── __init__.py
├── conftest.py
├── integrations
├── __init__.py
├── test_checkpoints.py
├── test_duplicate.py
├── test_mixins.py
└── test_real_cloud.py
└── test_io_cloud.py
/.codecov.yml:
--------------------------------------------------------------------------------
1 | # see https://docs.codecov.io/docs/codecov-yaml
2 | # Validation check:
3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate
4 |
5 | # https://docs.codecov.io/docs/codecovyml-reference
6 | codecov:
7 | bot: "codecov-io"
8 | strict_yaml_branch: "yaml-config"
9 | require_ci_to_pass: yes
10 | notify:
11 | # after_n_builds: 2
12 | wait_for_ci: yes
13 |
14 | coverage:
15 | precision: 0 # 2 = xx.xx%, 0 = xx%
16 | round: nearest # how coverage is rounded: down/up/nearest
17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green
18 | status:
19 | # https://codecov.readme.io/v1.0/docs/commit-status
20 | project:
21 | default:
22 | informational: true
23 | target: 99% # specify the target coverage for each commit status
24 | threshold: 30% # allow this much decrease on project
25 | # https://github.com/codecov/support/wiki/Filtering-Branches
26 | # branches: main
27 | if_ci_failed: error
28 | # https://github.com/codecov/support/wiki/Patch-Status
29 | patch:
30 | default:
31 | informational: true
32 | target: 50% # specify the target "X%" coverage to hit
33 | # threshold: 50% # allow this much decrease on patch
34 | changes: false
35 |
36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations
37 | github_checks:
38 | annotations: false
39 |
40 | parsers:
41 | gcov:
42 | branch_detection:
43 | conditional: true
44 | loop: true
45 | macro: false
46 | method: false
47 | javascript:
48 | enable_partials: false
49 |
50 | comment:
51 | layout: header, diff
52 | require_changes: false
53 | behavior: default # update if exists else create new
54 | # branches: *
55 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Each line is a file pattern followed by one or more owners.
2 |
3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
4 | # @global-owner1 and @global-owner2 will be requested for review when someone opens a pull request.
5 | * @borda @ethanwharris @justusschock
6 |
7 | # CI/CD and configs
8 | /.github/ @borda
9 | *.yml @borda
10 |
11 | /README.md @williamfalcon
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🐛 Bug
10 |
11 |
12 |
13 | ### To Reproduce
14 |
15 | Steps to reproduce the behavior:
16 |
17 | 1. Go to '...'
18 | 2. Run '....'
19 | 3. Scroll down to '....'
20 | 4. See error
21 |
22 |
23 |
24 | #### Code sample
25 |
26 |
28 |
29 | ### Expected behavior
30 |
31 |
32 |
33 | ### Environment
34 |
35 | - PyTorch Version (e.g., 1.0):
36 | - OS (e.g., Linux):
37 | - How you installed PyTorch (`conda`, `pip`, source):
38 | - Build command you used (if compiling from source):
39 | - Python version:
40 | - CUDA/cuDNN version:
41 | - GPU models and configuration:
42 | - Any other relevant information:
43 |
44 | ### Additional context
45 |
46 |
47 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Typos and doc fixes
3 | about: Typos and doc fixes
4 | title: ''
5 | labels: documentation
6 | assignees: ''
7 | ---
8 |
9 | ## 📚 Documentation
10 |
11 | For typos and doc fixes, please go ahead and:
12 |
13 | 1. Create an issue.
14 | 2. Fix the typo.
15 | 3. Submit a PR.
16 |
17 | Thanks!
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Feature
10 |
11 |
12 |
13 | ### Motivation
14 |
15 |
16 |
17 | ### Pitch
18 |
19 |
20 |
21 | ### Alternatives
22 |
23 |
24 |
25 | ### Additional context
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 | Before submitting
3 |
4 | - [ ] Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
6 | - [ ] Did you make sure to update the docs?
7 | - [ ] Did you write any new necessary tests?
8 |
9 |
10 |
11 | ## What does this PR do?
12 |
13 | Fixes # (issue).
14 |
15 | ## PR review
16 |
17 | Anyone in the community is free to review the PR once the tests have passed.
18 | If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
19 |
20 | ## Did you have fun?
21 |
22 | Make sure you had fun coding 🙃
23 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # Basic dependabot.yml file with
2 | # minimum configuration for two package managers
3 |
4 | version: 2
5 | updates:
6 | # Enable version updates for python
7 | - package-ecosystem: "pip"
8 | # Look for `requirements` files in the root directory
9 | directory: "/"
10 | # Check for updates once a day
11 | schedule:
12 | interval: "daily"
13 | # Labels on pull requests for version updates only
14 | labels:
15 | - "dependencies"
16 | pull-request-branch-name:
17 | # Separate sections of the branch name with a hyphen
18 | # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`
19 | separator: "-"
20 | # Allow up to 5 open pull requests for pip dependencies
21 | open-pull-requests-limit: 5
22 |
23 | # Enable version updates for GitHub Actions
24 | - package-ecosystem: "github-actions"
25 | directory: "/"
26 | # Check for updates once a month
27 | schedule:
28 | interval: "monthly"
29 | # Labels on pull requests for version updates only
30 | labels: ["CI"]
31 | pull-request-branch-name:
32 | # Separate sections of the branch name with a hyphen for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`
33 | separator: "-"
34 | # Allow up to 5 open pull requests for GitHub Actions
35 | open-pull-requests-limit: 5
36 | groups:
37 | GHA-updates:
38 | patterns:
39 | - "*"
40 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # https://github.com/marketplace/stale
2 |
3 | # Number of days of inactivity before an issue becomes stale
4 | daysUntilStale: 60
5 | # Number of days of inactivity before a stale issue is closed
6 | daysUntilClose: 14
7 | # Issues with these labels will never be considered stale
8 | exemptLabels:
9 | - pinned
10 | - security
11 | # Label to use when marking an issue as stale
12 | staleLabel: won't fix
13 | # Comment to post when marking an issue as stale. Set to `false` to disable
14 | markComment: >
15 | This issue has been automatically marked as stale because it has not had
16 | recent activity. It will be closed if no further activity occurs. Thank you
17 | for your contributions.
18 | # Comment to post when closing a stale issue. Set to `false` to disable
19 | closeComment: false
20 |
21 | # Set to true to ignore issues in a project (defaults to false)
22 | exemptProjects: true
23 | # Set to true to ignore issues in a milestone (defaults to false)
24 | exemptMilestones: true
25 | # Set to true to ignore issues with an assignee (defaults to false)
26 | exemptAssignees: true
27 |
--------------------------------------------------------------------------------
/.github/workflows/ci-checks.yml:
--------------------------------------------------------------------------------
1 | name: General checks
2 |
3 | on:
4 | push:
5 | branches: [main, "release/*"]
6 | pull_request:
7 | branches: [main, "release/*"]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
12 |
13 | jobs:
14 | #check-precommit:
15 | # uses: Lightning-AI/utilities/.github/workflows/check-precommit.yml@main
16 |
17 | check-typing:
18 | uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main
19 | with:
20 | actions-ref: main
21 | extra-typing: "typing"
22 |
23 | check-schema:
24 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
25 | with:
26 | azure-dir: ""
27 |
28 | check-package:
29 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
30 | with:
31 | actions-ref: main
32 | import-name: "litmodels"
33 | artifact-name: dist-packages-${{ github.sha }}
34 | testing-matrix: |
35 | {
36 | "os": ["ubuntu-latest", "macos-latest", "windows-latest"],
37 | "python-version": ["3.9", "3.12"]
38 | }
39 |
40 | # check-docs:
41 | # uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main
42 | # with:
43 | # requirements-file: "_requirements/_docs.txt"
44 |
--------------------------------------------------------------------------------
/.github/workflows/ci-cloud.yml:
--------------------------------------------------------------------------------
1 | name: Cloud integration
2 |
3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on:
5 | push:
6 | workflow_dispatch:
7 | schedule:
8 | - cron: "0 0 * * *"
9 |
10 | defaults:
11 | run:
12 | shell: bash
13 |
14 | jobs:
15 | access-secrets:
16 | # try to access secrets to ensure they are available;
17 | # if not, set the job output to false
18 | runs-on: ubuntu-latest
19 | outputs:
20 | secrets_available: ${{ steps.check_secrets.outputs.secrets_available }}
21 | steps:
22 | - name: Check secrets
23 | id: check_secrets
24 | run: |
25 | if [[ -z "${{ secrets.LIGHTNING_USER_ID }}" || -z "${{ secrets.LIGHTNING_API_KEY }}" ]]; then
26 | echo "Secrets are not set. Exiting..."
27 | echo "secrets_available=false" >> $GITHUB_OUTPUT
28 | else
29 | echo "Secrets are available."
30 | echo "secrets_available=true" >> $GITHUB_OUTPUT
31 | fi
32 |
33 | integration:
34 | needs: access-secrets
35 | if: needs.access-secrets.outputs.secrets_available == 'true'
36 | runs-on: ${{ matrix.os }}
37 | strategy:
38 | fail-fast: false
39 | matrix:
40 | os: ["ubuntu-22.04", "macOS-13", "windows-2022"]
41 | python-version: ["3.10"]
42 |
43 | # Timeout: https://stackoverflow.com/a/59076067/4521646
44 | timeout-minutes: 25
45 | env:
46 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
47 |
48 | steps:
49 | - uses: actions/checkout@v4
50 | - uses: actions/setup-python@v5
51 | with:
52 | python-version: ${{ matrix.python-version }}
53 | cache: "pip"
54 |
55 | - name: Install package & dependencies
56 | run: |
57 | pip --version
58 | pip install -e '.[test,extra]' -U -q --find-links $TORCH_URL
59 | pip list
60 |
61 | - name: Test integrations
62 | env:
63 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }}
64 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }}
65 | run: |
66 | coverage run --source litmodels -m pytest src tests -v -m cloud
67 | timeout-minutes: 15
68 |
69 | - name: Statistics
70 | run: |
71 | coverage report
72 | coverage xml
73 |
74 | - name: Upload coverage to Codecov
75 | uses: codecov/codecov-action@v5
76 | with:
77 | token: ${{ secrets.CODECOV_TOKEN }}
78 | files: ./coverage.xml
79 | flags: integrations
80 | env_vars: OS,PYTHON
81 | name: codecov-umbrella
82 | fail_ci_if_error: false
83 |
84 | integration-guardian:
85 | runs-on: ubuntu-latest
86 | needs: integration
87 | if: always()
88 | steps:
89 | - run: echo "${{ needs.integration.result }}"
90 | - name: failing...
91 | if: needs.integration.result == 'failure'
92 | run: exit 1
93 | - name: cancelled or skipped...
94 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.integration.result)
95 | timeout-minutes: 1
96 | run: sleep 90
97 |
98 | # todo add job to report failing tests with cron
99 |
--------------------------------------------------------------------------------
/.github/workflows/ci-testing.yml:
--------------------------------------------------------------------------------
1 | name: CI testing
2 |
3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the main branch
5 | push: {}
6 | pull_request:
7 | branches: [main]
8 |
9 | defaults:
10 | run:
11 | shell: bash
12 |
13 | jobs:
14 | pytester:
15 | runs-on: ${{ matrix.os }}
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | os: ["ubuntu-24.04", "macOS-13", "windows-2022"]
20 | python-version: ["3.9", "3.12"]
21 | requires: ["latest"]
22 | dependency: ["lightning"]
23 | include:
24 | - { requires: "oldest", dependency: "lightning", os: "ubuntu-22.04", python-version: "3.9" }
25 | - { requires: "latest", dependency: "pytorch_lightning", os: "ubuntu-24.04", python-version: "3.12" }
26 | - { requires: "latest", dependency: "pytorch_lightning", os: "windows-2022", python-version: "3.12" }
27 | - { requires: "latest", dependency: "pytorch_lightning", os: "macOS-13", python-version: "3.12" }
28 |
29 | # Timeout: https://stackoverflow.com/a/59076067/4521646
30 | timeout-minutes: 35
31 | env:
32 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
33 |
34 | steps:
35 | - uses: actions/checkout@v4
36 | - name: Set up Python ${{ matrix.python-version }}
37 | uses: actions/setup-python@v5
38 | with:
39 | python-version: ${{ matrix.python-version }}
40 | cache: "pip"
41 |
42 | - name: Set min. dependencies
43 | if: matrix.requires == 'oldest'
44 | run: |
45 | pip install 'lightning-utilities[cli]'
46 | python -m lightning_utilities.cli requirements set-oldest --req_files='["requirements.txt"]'
47 |
48 | - name: Adjust requirements
49 | run: |
50 | pip install 'lightning-utilities[cli]' -U -q
51 | python -m lightning_utilities.cli requirements replace-pkg \
52 | --old_package="lightning" \
53 | --new_package="${{matrix.dependency}}" \
54 | --req_files='["_requirements/extra.txt"]'
55 | cat _requirements/extra.txt
56 |
57 | - name: Install package & dependencies
58 | run: |
59 | set -e
60 | pip --version
61 | pip install -e '.[test,extra]' -U -q --find-links $TORCH_URL
62 | pip list
63 | # check that right package was installed
64 | python -c "import ${{matrix.dependency}}; print(${{matrix.dependency}}.__version__)"
65 |
66 | - name: Tests with mocks
67 | run: |
68 | coverage run --source litmodels -m pytest src tests -v -m "not cloud"
69 |
70 | - name: Statistics
71 | run: |
72 | coverage report
73 | coverage xml
74 |
75 | - name: Upload coverage to Codecov
76 | uses: codecov/codecov-action@v5
77 | continue-on-error: true
78 | with:
79 | token: ${{ secrets.CODECOV_TOKEN }}
80 | files: ./coverage.xml
81 | flags: unittests
82 | env_vars: OS,PYTHON
83 | name: codecov-umbrella
84 | fail_ci_if_error: false
85 |
86 | tests-guardian:
87 | runs-on: ubuntu-latest
88 | needs: pytester
89 | if: always()
90 | steps:
91 | - run: echo "${{ needs.pytester.result }}"
92 | - name: failing...
93 | if: needs.pytester.result == 'failure'
94 | run: exit 1
95 | - name: cancelled or skipped...
96 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result)
97 | timeout-minutes: 1
98 | run: sleep 90
99 |
--------------------------------------------------------------------------------
/.github/workflows/cleanup-caches.yml:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries
2 | name: cleanup caches by a branch
3 | on:
4 | pull_request:
5 | types: [closed]
6 |
7 | jobs:
8 | pr-cleanup:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Check out code
12 | uses: actions/checkout@v4
13 |
14 | - name: Cleanup
15 | run: |
16 | gh extension install actions/gh-actions-cache
17 |
18 | REPO=${{ github.repository }}
19 | BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge"
20 |
21 | echo "Fetching list of cache key"
22 | cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 )
23 |
24 | ## Setting this to not fail the workflow while deleting cache keys.
25 | set +e
26 | echo "Deleting caches..."
27 | for cacheKey in $cacheKeysForPR
28 | do
29 | gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm
30 | done
31 | echo "Done"
32 | env:
33 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
34 |
--------------------------------------------------------------------------------
/.github/workflows/docs-build.yml:
--------------------------------------------------------------------------------
1 | name: "Build (& deploy) Docs"
2 | on:
3 | push:
4 | branches: [main]
5 | workflow_dispatch:
6 |
7 | jobs:
8 | build-docs:
9 | uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main
10 | with:
11 | requirements-file: "_requirements/_docs.txt"
12 |
13 | # https://github.com/marketplace/actions/deploy-to-github-pages
14 | docs-deploy:
15 | needs: build-docs
16 | runs-on: ubuntu-latest
17 | steps:
18 | - uses: actions/checkout@v4 # deploy needs git credentials
19 | - name: Download prepared docs
20 | uses: actions/download-artifact@v4
21 | with:
22 | name: docs-html-${{ github.sha }}
23 | path: docs/build/html
24 |
25 | - name: Deploy 🚀
26 | uses: JamesIves/github-pages-deploy-action@v4.7.3
27 | if: ${{ github.event_name == 'push' }}
28 | with:
29 | token: ${{ secrets.GITHUB_TOKEN }}
30 | branch: gh-pages # The branch the action should deploy to.
31 | folder: docs/build/html # The folder the action should deploy.
32 | clean: true # Automatically remove deleted files from the deploy branch
33 | target-folder: docs # If you'd like to push the contents of the deployment folder into a specific directory
34 | single-commit: true # you'd prefer to have a single commit on the deployment branch instead of full history
35 |
--------------------------------------------------------------------------------
/.github/workflows/greetings.yml:
--------------------------------------------------------------------------------
1 | name: Greetings
2 | # https://github.com/marketplace/actions/first-interaction
3 |
4 | on: [issues] # pull_request
5 |
6 | jobs:
7 | greeting:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/first-interaction@v1
11 | with:
12 | repo-token: ${{ secrets.GITHUB_TOKEN }}
13 | issue-message: "Hi! Thanks for your contribution — great first issue!"
14 | pr-message: "Hey thanks for the input! Please give us a bit of time to review it!"
15 |
--------------------------------------------------------------------------------
/.github/workflows/label-conflicts.yml:
--------------------------------------------------------------------------------
1 | name: Label conflicts
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request_target:
7 | types: ["synchronize", "reopened", "opened"]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}
11 | cancel-in-progress: false
12 |
13 | jobs:
14 | triage-conflicts:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd # Oct 25, 2021
18 | with:
19 | CONFLICT_LABEL_NAME: "has conflicts"
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 | MAX_RETRIES: 3
22 | WAIT_MS: 5000
23 |
--------------------------------------------------------------------------------
/.github/workflows/release-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the main branch
5 | push:
6 | branches: [main]
7 | release:
8 | types: [published]
9 |
10 | # based on https://github.com/pypa/gh-action-pypi-publish
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-22.04
15 | steps:
16 | - uses: actions/checkout@v4
17 | - uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install dependencies
22 | run: pip install -U -r _requirements/_ci.txt
23 | - name: Build package
24 | run: python -m build
25 | - name: Check package
26 | run: twine check dist/*
27 |
28 | # We do this since failures on test.pypi aren't that bad
29 | - name: Publish to Test PyPI
30 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
31 | uses: pypa/gh-action-pypi-publish@v1.12.4
32 | with:
33 | user: __token__
34 | password: ${{ secrets.TEST_PYPI_PASSWORD }}
35 | repository-url: https://test.pypi.org/legacy/
36 |
37 | - name: Publish distribution 📦 to PyPI
38 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
39 | uses: pypa/gh-action-pypi-publish@v1.12.4
40 | with:
41 | user: __token__
42 | password: ${{ secrets.pypi_password }}
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Sphinx documentation
58 | docs/_build/
59 | docs/source/api/
60 | docs/source/*.md
61 | docs/source/readme.rst
62 |
63 | # PyBuilder
64 | target/
65 |
66 | # Jupyter Notebook
67 | .ipynb_checkpoints
68 |
69 | # IPython
70 | profile_default/
71 | ipython_config.py
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
77 | __pypackages__/
78 |
79 | # Celery stuff
80 | celerybeat-schedule
81 | celerybeat.pid
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # mypy
96 | .mypy_cache/
97 | .dmypy.json
98 | dmypy.json
99 |
100 | # Pyre type checker
101 | .pyre/
102 |
103 | # PyCharm
104 | .idea/
105 |
106 | # Lightning logs
107 | lightning_logs
108 | *.gz
109 | .DS_Store
110 | .*_submit.py
111 |
112 | # Ruff
113 | .ruff_cache/
114 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
7 | autoupdate_schedule: "monthly"
8 | # submodules: true
9 |
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v5.0.0
13 | hooks:
14 | - id: end-of-file-fixer
15 | - id: trailing-whitespace
16 | exclude: README.md
17 | - id: check-case-conflict
18 | - id: check-yaml
19 | - id: check-toml
20 | - id: check-json
21 | - id: check-added-large-files
22 | - id: check-docstring-first
23 | - id: detect-private-key
24 |
25 | - repo: https://github.com/pre-commit/pygrep-hooks
26 | rev: v1.10.0
27 | hooks:
28 | - id: python-use-type-annotations
29 |
30 | - repo: https://github.com/codespell-project/codespell
31 | rev: v2.4.1
32 | hooks:
33 | - id: codespell
34 | additional_dependencies: [tomli]
35 | #args: ["--write-changes"]
36 |
37 | - repo: https://github.com/pre-commit/mirrors-prettier
38 | rev: v3.1.0
39 | hooks:
40 | - id: prettier
41 | files: \.(json|yml|yaml|toml)
42 | # https://prettier.io/docs/en/options.html#print-width
43 | args: ["--print-width=120"]
44 |
45 | - repo: https://github.com/executablebooks/mdformat
46 | rev: 0.7.22
47 | hooks:
48 | - id: mdformat
49 | args: ["--number"]
50 | additional_dependencies:
51 | - mdformat-gfm
52 | - mdformat-black
53 | - mdformat_frontmatter
54 | exclude: CHANGELOG.md
55 |
56 | - repo: https://github.com/astral-sh/ruff-pre-commit
57 | rev: v0.11.12
58 | hooks:
59 | - id: ruff
60 | args: ["--fix"]
61 | - id: ruff-format
62 | - id: ruff
63 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | sphinx:
10 | configuration: docs/source/conf.py
11 | fail_on_warning: true
12 |
13 | # Optionally build your docs in additional formats such as PDF and ePub
14 | formats: all
15 |
16 | # Optionally set the version of Python and requirements required to build your docs
17 | python:
18 | version: 3.7
19 | install:
20 | - requirements: _requirements/_docs.txt
21 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 |
8 | ## [Unreleased] - YYYY-MM-DD
9 |
10 | ### Added
11 |
12 | ### Changed
13 |
14 | ### Fixed
15 |
16 | ### Removed
17 |
18 | ### Deprecated
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2021-2024 Lightning-AI team
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Manifest syntax https://setuptools.pypa.io/en/latest/userguide/manifest.html
2 | graft wheelhouse
3 |
4 | recursive-exclude __pycache__ *.py[cod] *.orig
5 |
6 | # Include the README and CHANGELOG
7 | include *.md
8 | recursive-include src *.md
9 |
10 | # Include the license file
11 | include LICENSE
12 |
13 | # Exclude build configs
14 | exclude *.sh
15 | exclude *.toml
16 | exclude *.svg
17 | exclude *.yml
18 | exclude *.yaml
19 |
20 | # exclude tests from package
21 | recursive-exclude tests *
22 | recursive-exclude site *
23 | exclude tests
24 |
25 | # Exclude the documentation files
26 | recursive-exclude docs *
27 | exclude docs
28 |
29 | # Include the Requirements
30 | include requirements.txt
31 | recursive-include _requirements *.txt
32 |
33 | # Exclude Makefile
34 | exclude Makefile
35 |
36 | prune .git
37 | prune .github
38 | prune .circleci
39 | prune notebook*
40 | prune temp*
41 | prune test*
42 | prune benchmark*
43 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test clean docs
2 |
3 | # to imitate SLURM set only single node
4 | export SLURM_LOCALID=0
5 | # assume you have installed the needed packages
6 | export SPHINX_MOCK_REQUIREMENTS=0
7 |
8 | test: clean
9 | pip install -q -r requirements.txt
10 | pip install -q -r _requirements/test.txt
11 |
12 | # use this to run tests
13 | python -m coverage run --source litmodels -m pytest src tests -v --flake8
14 | python -m coverage report
15 |
16 | docs: clean
17 | pip install . --quiet -r _requirements/_docs.txt
18 | python -m sphinx -b html -W --keep-going docs/source docs/build
19 |
20 | clean:
21 | # clean all temp runs
22 | rm -rf $(shell find . -name "mlruns")
23 | rm -rf .mypy_cache
24 | rm -rf .pytest_cache
25 | rm -rf ./docs/build
26 | rm -rf ./docs/source/**/generated
27 | rm -rf ./docs/source/api
28 | rm -rf ./src/*.egg-info
29 | rm -rf ./build
30 | rm -rf ./dist
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Save, share and host AI model checkpoints Lightning fast ⚡
4 |
5 |
6 |
7 |
8 |
9 | ______________________________________________________________________
10 |
11 | Save, load, host, and share models without slowing down training.
12 | **LitModels** minimizes training slowdowns from checkpoint saving. Share public links on Lightning AI or your own cloud with enterprise-grade access controls.
13 |
14 |
15 |
16 |
17 | ✅ Checkpoint without slowing training. ✅ Granular access controls.
18 | ✅ Load models anywhere. ✅ Host on Lightning or your own cloud.
19 |
20 |
21 | [](https://discord.gg/WajDThKAur)
22 | 
23 | 
24 | [](https://codecov.io/gh/Lightning-AI/LitModels)
25 | [](https://github.com/Lightning-AI/LitModels/blob/main/LICENSE)
26 |
27 |
35 |
36 |
37 |
38 | # Quick start
39 |
40 | Install LitModels via pip:
41 |
42 | ```bash
43 | pip install litmodels
44 | ```
45 |
46 | Toy example ([see real examples](#examples)):
47 |
48 | ```python
49 | import litmodels as lm
50 | import torch
51 |
52 | # save a model
53 | model = torch.nn.Module()
54 | lm.save_model(model=model, name="model-name")
55 |
56 | # load a model
57 | model = lm.load_model(name="model-name")
58 | ```
59 |
60 | # Examples
61 |
62 |
63 | PyTorch
64 |
65 | Save model:
66 |
67 | ```python
68 | import torch
69 | from litmodels import save_model
70 |
71 | model = torch.nn.Module()
72 | save_model(model=model, name="your_org/your_team/torch-model")
73 | ```
74 |
75 | Load model:
76 |
77 | ```python
78 | from litmodels import load_model
79 |
80 | model_ = load_model(name="your_org/your_team/torch-model")
81 | ```
82 |
83 |
84 |
85 |
86 | PyTorch Lightning
87 |
88 | Save model:
89 |
90 | ```python
91 | from lightning import Trainer
92 | from litmodels import upload_model
93 | from litmodels.demos import BoringModel
94 |
95 | # Configure Lightning Trainer
96 | trainer = Trainer(max_epochs=2)
97 | # Define the model and train it
98 | trainer.fit(BoringModel())
99 |
100 | # Upload the best model to cloud storage
101 | checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
102 | # Define the model name - this should be unique to your model
103 | upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>")
104 | ```
105 |
106 | Load model:
107 |
108 | ```python
109 | from lightning import Trainer
110 | from litmodels import download_model
111 | from litmodels.demos import BoringModel
112 |
113 | # Load the model from cloud storage
114 | checkpoint_path = download_model(
115 | # Define the model name and version - this needs to be unique to your model
116 | name="//:",
117 | download_dir="my_models",
118 | )
119 | print(f"model: {checkpoint_path}")
120 |
121 | # Train the model with extended training period
122 | trainer = Trainer(max_epochs=4)
123 | trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
124 | ```
125 |
126 |
127 |
128 |
129 | TensorFlow / Keras
130 |
131 | Save model:
132 |
133 | ```python
134 | from tensorflow import keras
135 |
136 | from litmodels import save_model
137 |
138 | # Define the model
139 | model = keras.Sequential(
140 | [
141 | keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
142 | keras.layers.Dense(10, name="dense_2"),
143 | ]
144 | )
145 |
146 | # Compile the model
147 | model.compile(optimizer="adam", loss="categorical_crossentropy")
148 |
149 | # Save the model
150 | save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
151 | ```
152 |
153 | Load model:
154 |
155 | ```python
156 | from litmodels import load_model
157 |
158 | model_ = load_model(
159 | "lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model"
160 | )
161 | ```
162 |
163 |
164 |
165 |
166 | SKLearn
167 |
168 | Save model:
169 |
170 | ```python
171 | from sklearn import datasets, model_selection, svm
172 | from litmodels import save_model
173 |
174 | # Load example dataset
175 | iris = datasets.load_iris()
176 | X, y = iris.data, iris.target
177 |
178 | # Split dataset into training and test sets
179 | X_train, X_test, y_train, y_test = model_selection.train_test_split(
180 | X, y, test_size=0.2, random_state=42
181 | )
182 |
183 | # Train a simple SVC model
184 | model = svm.SVC()
185 | model.fit(X_train, y_train)
186 |
187 | # Upload the saved model using litmodels
188 | save_model(model=model, name="your_org/your_team/sklearn-svm-model")
189 | ```
190 |
191 | Use model:
192 |
193 | ```python
194 | from litmodels import load_model
195 |
196 | # Download and load the model file from cloud storage
197 | model = load_model(
198 | name="your_org/your_team/sklearn-svm-model", download_dir="my_models"
199 | )
200 |
201 | # Example: run inference with the loaded model
202 | sample_input = [[5.1, 3.5, 1.4, 0.2]]
203 | prediction = model.predict(sample_input)
204 | print(f"Prediction: {prediction}")
205 | ```
206 |
207 |
208 |
209 | # Features
210 |
211 |
212 | PyTorch Lightning Callback
213 |
214 | Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch.
215 |
216 | ```python
217 | import torch.utils.data as data
218 | import torchvision as tv
219 | from lightning import Trainer
220 | from litmodels.integrations import LightningModelCheckpoint
221 | from litmodels.demos import BoringModel
222 |
223 | dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
224 | train, val = data.random_split(dataset, [55000, 5000])
225 |
226 | trainer = Trainer(
227 | max_epochs=2,
228 | callbacks=[
229 | LightningModelCheckpoint(
230 | # Define the model name - this should be unique to your model
231 | model_registry="//",
232 | )
233 | ],
234 | )
235 | trainer.fit(
236 | BoringModel(),
237 | data.DataLoader(train, batch_size=256),
238 | data.DataLoader(val, batch_size=256),
239 | )
240 | ```
241 |
242 |
243 |
244 |
245 | Save any Python class as a checkpoint
246 |
247 | Mixin classes streamline model management in Python by modularizing reusable functionalities like saving/loading, enabling consistent, conflict-free, and maintainable code across multiple models.
248 |
249 | **Save model:**
250 |
251 | ```python
252 | from litmodels.integrations.mixins import PickleRegistryMixin
253 |
254 |
255 | class MyModel(PickleRegistryMixin):
256 | def __init__(self, param1, param2):
257 | self.param1 = param1
258 | self.param2 = param2
259 | # Your model initialization code
260 | ...
261 |
262 |
263 | # Create and push a model instance
264 | model = MyModel(param1=42, param2="hello")
265 | model.upload_model(name="my-org/my-team/my-model")
266 | ```
267 |
268 | Load model:
269 |
270 | ```python
271 | loaded_model = MyModel.download_model(name="my-org/my-team/my-model")
272 | ```
273 |
274 |
275 |
276 |
277 | Save custom PyTorch models
278 |
279 | Mixin classes centralize serialization logic, eliminating redundant code and ensuring consistent, error-free model persistence across projects.
280 | The `download_model` method bypasses constructor arguments entirely, reconstructing the model directly from the registry with pre-configured architecture and weights, eliminating initialization mismatches.
281 |
282 | Save model:
283 |
284 | ```python
285 | import torch
286 | from litmodels.integrations.mixins import PyTorchRegistryMixin
287 |
288 |
289 | # Important: PyTorchRegistryMixin must be first in the inheritance order
290 | class MyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
291 | def __init__(self, input_size, hidden_size=128):
292 | super().__init__()
293 | self.linear = torch.nn.Linear(input_size, hidden_size)
294 | self.activation = torch.nn.ReLU()
295 |
296 | def forward(self, x):
297 | return self.activation(self.linear(x))
298 |
299 |
300 | # Create and push the model
301 | model = MyTorchModel(input_size=784)
302 | model.upload_model(name="my-org/my-team/torch-model")
303 | ```
304 |
305 | Use the model:
306 |
307 | ```python
308 | # Pull the model with the same architecture
309 | loaded_model = MyTorchModel.download_model(name="my-org/my-team/torch-model")
310 | ```
311 |
312 |
313 |
314 | # Performance
315 |
316 |
319 |
320 | # Community
321 |
322 | 💬 [Get help on Discord](https://discord.com/invite/XncpTy7DSt)\
323 | 📋 [License: Apache 2.0](https://github.com/Lightning-AI/litModels/blob/main/LICENSE)
324 |
--------------------------------------------------------------------------------
/_requirements/_ci.txt:
--------------------------------------------------------------------------------
1 | build ==1.2.*
2 | twine ==6.1.*
3 |
--------------------------------------------------------------------------------
/_requirements/_docs.txt:
--------------------------------------------------------------------------------
1 | sphinx >=6.0,<7.0
2 | myst-parser >=2.0.0
3 | nbsphinx >=0.8.5
4 | pandoc >=1.0
5 | pypandoc-binary
6 | docutils >=0.16
7 | sphinxcontrib-fulltoc >=1.0
8 | sphinxcontrib-mockautodoc
9 | sphinx-autodoc-typehints >=1.0
10 | sphinx-paramlinks >=0.5.1
11 | sphinx-togglebutton >=0.2
12 | sphinx-copybutton >=0.3
13 | # jinja2 >=3.0.0,<3.2.0
14 |
15 | pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip
16 |
--------------------------------------------------------------------------------
/_requirements/extra.txt:
--------------------------------------------------------------------------------
1 | lightning >= 2.0.0
2 | numpy <2.0.0 ; platform_system == "Darwin" # compatibility fix for Torch
3 | joblib >= 1.0.0
4 |
--------------------------------------------------------------------------------
/_requirements/test.txt:
--------------------------------------------------------------------------------
1 | coverage >=5.0
2 | pytest >=6.0
3 | pytest-cov
4 | pytest-mock
5 |
6 | pytorch-lightning >=2.0
7 | scikit-learn >=1.0
8 | huggingface-hub >=0.29.0
9 | tensorflow >=2.0
10 |
--------------------------------------------------------------------------------
/_requirements/typing.txt:
--------------------------------------------------------------------------------
1 | mypy ==1.16.0
2 |
--------------------------------------------------------------------------------
/docs/.build_docs.sh:
--------------------------------------------------------------------------------
1 | make clean
2 | make html --debug --jobs $(nproc)
3 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W
6 | SPHINXBUILD = python $(shell which sphinx-build)
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/copybutton.js:
--------------------------------------------------------------------------------
1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */
2 | $(document).ready(function () {
3 | /* Add a [>>>] button on the top-right corner of code samples to hide
4 | * the >>> and ... prompts and the output and thus make the code
5 | * copyable. */
6 | var div = $(
7 | ".highlight-python .highlight," +
8 | ".highlight-python3 .highlight," +
9 | ".highlight-pycon .highlight," +
10 | ".highlight-default .highlight",
11 | );
12 | var pre = div.find("pre");
13 |
14 | // get the styles from the current theme
15 | pre.parent().parent().css("position", "relative");
16 | var hide_text = "Hide the prompts and output";
17 | var show_text = "Show the prompts and output";
18 | var border_width = pre.css("border-top-width");
19 | var border_style = pre.css("border-top-style");
20 | var border_color = pre.css("border-top-color");
21 | var button_styles = {
22 | cursor: "pointer",
23 | position: "absolute",
24 | top: "0",
25 | right: "0",
26 | "border-color": border_color,
27 | "border-style": border_style,
28 | "border-width": border_width,
29 | color: border_color,
30 | "text-size": "75%",
31 | "font-family": "monospace",
32 | "padding-left": "0.2em",
33 | "padding-right": "0.2em",
34 | "border-radius": "0 3px 0 0",
35 | };
36 |
37 | // create and add the button to all the code blocks that contain >>>
38 | div.each(function (index) {
39 | var jthis = $(this);
40 | if (jthis.find(".gp").length > 0) {  // ".gp" is Pygments' class for interpreter prompts (>>>)
41 | var button = $('>>> ');  // NOTE(review): element markup appears stripped in this listing — upstream creates a <span class="copybutton"> here; confirm against the original file
42 | button.css(button_styles);
43 | button.attr("title", hide_text);
44 | button.data("hidden", "false");  // toggle state stored as a string, read back in the click handler below
45 | jthis.prepend(button);
46 | }
47 | // tracebacks (.gt) contain bare text elements that need to be
48 | // wrapped in a span to work with .nextUntil() (see later)
49 | jthis
50 | .find("pre:has(.gt)")
51 | .contents()
52 | .filter(function () {
53 | return this.nodeType == 3 && this.data.trim().length > 0;  // nodeType 3 = DOM text node; skip whitespace-only runs
54 | })
55 | .wrap("");  // NOTE(review): wrapper markup looks stripped in this listing — upstream wraps each text node in a <span>; confirm
56 | });
57 |
58 | // define the behavior of the button when it's clicked
59 | $(".copybutton").click(function (e) {
60 | e.preventDefault();
61 | var button = $(this);
62 | if (button.data("hidden") === "false") {  // currently showing prompts/output -> hide them
63 | // hide the code output
64 | button.parent().find(".go, .gp, .gt").hide();
65 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "hidden");
66 | button.css("text-decoration", "line-through");
67 | button.attr("title", show_text);
68 | button.data("hidden", "true");
69 | } else {
70 | // show the code output
71 | button.parent().find(".go, .gp, .gt").show();
72 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "visible");
73 | button.css("text-decoration", "none");
74 | button.attr("title", hide_text);
75 | button.data("hidden", "false");
76 | }
77 | });
78 | });
79 |
--------------------------------------------------------------------------------
/docs/source/_static/images/icon.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/docs/source/_static/images/logo-large.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/docs/source/_static/images/logo-small.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/docs/source/_static/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/LitModels/7717ff0bc7e76745d9458305257ab19fd2346f5e/docs/source/_static/images/logo.png
--------------------------------------------------------------------------------
/docs/source/_static/images/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docs/source/_templates/theme_variables.jinja:
--------------------------------------------------------------------------------
1 | {%- set external_urls = {
2 | 'github': 'https://github.com/Lightning-AI/litModels',
3 | 'github_issues': 'https://github.com/Lightning-AI/litModels/issues',
4 | 'contributing': 'https://github.com/Lightning-AI/lightning/blob/master/CONTRIBUTING.md',
5 | 'governance': 'https://github.com/Lightning-AI/lightning/blob/master/governance.md',
6 | 'docs': 'https://lightning-ai.github.io/models/',
7 | 'twitter': 'https://twitter.com/LightningAI',
8 | 'discuss': 'https://discord.com/invite/tfXFetEZxv',
9 | 'tutorials': 'https://lightning.ai',
10 | 'previous_pytorch_versions': 'https://lightning-ai.github.io/models/',
11 | 'home': 'https://lightning-ai.github.io/models/',
12 | 'get_started': 'https://lightning.ai',
13 | 'features': 'https://lightning-ai.github.io/models/',
14 | 'blog': 'https://www.Lightning.ai/blog',
15 | 'resources': 'https://lightning.ai',
16 | 'support': 'https://lightning-ai.github.io/models/',
17 | }
18 | -%}
19 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Configuration file for the Sphinx documentation builder.
3 | #
4 | # This file does only contain a selection of the most common options. For a
5 | # full list see the documentation:
6 | # http://www.sphinx-doc.org/en/master/config
7 | import glob
8 | import inspect
9 | import os
10 | import re
11 | import sys
12 | from importlib.util import module_from_spec, spec_from_file_location
13 |
14 | import pt_lightning_sphinx_theme
15 | import pypandoc
16 |
17 | _PATH_HERE = os.path.abspath(os.path.dirname(__file__))
18 | _PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
19 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
20 | sys.path.insert(0, os.path.abspath(_PATH_ROOT))
21 |
22 | SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))
23 |
24 | # alternative https://stackoverflow.com/a/67692/4521646
25 | spec = spec_from_file_location(
26 | "litmodels/__about__.py", os.path.join(_PATH_SOURCE, "litmodels", "__about__.py")
27 | )
28 | about = module_from_spec(spec)
29 | spec.loader.exec_module(about)
30 |
31 | # -- Project information -----------------------------------------------------
32 |
33 | # this name shall match the project name in Github as it is used for linking to code
34 | project = "models"
35 | copyright = about.__copyright__
36 | author = about.__author__
37 |
38 | # The short X.Y version
39 | version = about.__version__
40 | # The full version, including alpha/beta/rc tags
41 | release = about.__version__
42 |
43 | # Options for the linkcode extension
44 | # ----------------------------------
45 | github_user = "Lightning-AI"
46 | github_repo = project
47 |
48 | # -- Project documents -------------------------------------------------------
49 |
50 |
51 | def _transform_changelog(path_in: str, path_out: str) -> None:  # rewrite CHANGELOG so "###" sub-titles become unique per release (needed for stable Sphinx section anchors)
52 | with open(path_in) as fp:
53 | chlog_lines = fp.readlines()
54 | # enrich short subsub-titles to be unique
55 | chlog_ver = ""
56 | for i, ln in enumerate(chlog_lines):
57 | if ln.startswith("## "):  # release heading, e.g. "## 1.2.3 - 2024-01-01"
58 | chlog_ver = ln[2:].split("-")[0].strip()  # keep the part before the first "-" as the version; assumes the version itself contains no "-" (e.g. no "-rc1") — TODO confirm
59 | elif ln.startswith("### "):  # section heading, e.g. "### Fixed"
60 | ln = ln.replace("###", f"### {chlog_ver} -")  # prefix with the current release version
61 | chlog_lines[i] = ln
62 | with open(path_out, "w") as fp:
63 | fp.writelines(chlog_lines)
64 |
65 |
66 | def _convert_markdown(path_in: str, path_out: str) -> None:  # convert a Markdown document to reST via pandoc for the Sphinx build
67 | with open(path_in) as fp:
68 | readme = fp.read()
69 | # TODO: temp fix removing SVG badges and GIF, because they are automatically 100% wide
70 | readme = re.sub(r"(\[!\[.*\))", "", readme)  # drop badge links of the form [![...](...)](...)
71 | readme = re.sub(r"(!\[.*.gif\))", "", readme)  # drop inline GIF images ![...](*.gif); note "." before "gif" is an unescaped regex wildcard
72 | folder_names = (
73 | os.path.basename(p)
74 | for p in glob.glob(os.path.join(_PATH_ROOT, "*"))
75 | if os.path.isdir(p)
76 | )
77 | for dir_name in folder_names:  # rewrite repo-relative links to absolute paths so they resolve from the docs build dir
78 | readme = readme.replace(
79 | "](%s/" % dir_name, "](%s/" % os.path.join(_PATH_ROOT, dir_name)
80 | )
81 | readme = pypandoc.convert_text(readme, format="markdown", to="rst")
82 | with open(path_out, "w") as fp:
83 | fp.write(readme)
84 |
85 |
86 | # export the README
87 | _convert_markdown(os.path.join(_PATH_ROOT, "README.md"), "readme.rst")
88 |
89 | # -- General configuration ---------------------------------------------------
90 |
91 | # If your documentation needs a minimal Sphinx version, state it here.
92 |
93 | needs_sphinx = "6.2"
94 |
95 | # Add any Sphinx extension module names here, as strings. They can be
96 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
97 | # ones.
98 | extensions = [
99 | "sphinx.ext.autodoc",
100 | # 'sphinxcontrib.mockautodoc', # raises error: directive 'automodule' is already registered ...
101 | # 'sphinxcontrib.fulltoc', # breaks pytorch-theme with unexpected kw argument 'titles_only'
102 | "sphinx.ext.doctest",
103 | "sphinx.ext.intersphinx",
104 | "sphinx.ext.todo",
105 | "sphinx.ext.coverage",
106 | "sphinx.ext.linkcode",
107 | "sphinx.ext.autosummary",
108 | "sphinx.ext.napoleon",
109 | "sphinx.ext.imgmath",
110 | "myst_parser",
111 | "sphinx.ext.autosectionlabel",
112 | "nbsphinx",
113 | "sphinx_autodoc_typehints",
114 | "sphinx_paramlinks",
115 | "sphinx.ext.githubpages",
116 | "pt_lightning_sphinx_theme.extensions.lightning",
117 | ]
118 |
119 | # Add any paths that contain templates here, relative to this directory.
120 | templates_path = ["_templates"]
121 |
122 | myst_update_mathjax = False
123 |
124 | # https://berkeley-stat159-f17.github.io/stat159-f17/lectures/14-sphinx..html#conf.py-(cont.)
125 | # https://stackoverflow.com/questions/38526888/embed-ipython-notebook-in-sphinx-document
126 | # I execute the notebooks manually in advance. If notebooks test the code,
127 | # they should be run at build time.
128 | nbsphinx_execute = "never"
129 | nbsphinx_allow_errors = True
130 | nbsphinx_requirejs_path = ""
131 |
132 | # The suffix(es) of source filenames.
133 | # You can specify multiple suffix as a list of string:
134 | #
135 | # source_suffix = ['.rst', '.md']
136 | # source_suffix = ['.rst', '.md', '.ipynb']
137 | source_suffix = {
138 | ".rst": "restructuredtext",
139 | ".txt": "markdown",
140 | ".md": "markdown",
141 | ".ipynb": "nbsphinx",
142 | }
143 |
144 | # The master toctree document.
145 | master_doc = "index"
146 |
147 | # The language for content autogenerated by Sphinx. Refer to documentation
148 | # for a list of supported languages.
149 | #
150 | # This is also used if you do content translation via gettext catalogs.
151 | # Usually you set "language" from the command line for these cases.
152 | language = "en"
153 |
154 | # List of patterns, relative to source directory, that match files and
155 | # directories to ignore when looking for source files.
156 | # This pattern also affects html_static_path and html_extra_path.
157 | exclude_patterns = [
158 | "PULL_REQUEST_TEMPLATE.md",
159 | ]
160 |
161 | # The name of the Pygments (syntax highlighting) style to use.
162 | pygments_style = None
163 |
164 | # -- Options for HTML output -------------------------------------------------
165 |
166 | # The theme to use for HTML and HTML Help pages. See the documentation for
167 | # a list of builtin themes.
168 | #
169 | html_theme = "pt_lightning_sphinx_theme"
170 | html_theme_path = [pt_lightning_sphinx_theme.get_html_theme_path()]
171 |
172 | # Theme options are theme-specific and customize the look and feel of a theme
173 | # further. For a list of options available for each theme, see the
174 | # documentation.
175 |
176 | html_theme_options = {
177 | "pytorch_project": about.__homepage__,
178 | "canonical_url": about.__homepage__,
179 | "collapse_navigation": False,
180 | "display_version": True,
181 | "logo_only": False,
182 | }
183 |
184 | html_favicon = "_static/images/icon.svg"
185 |
186 | # Add any paths that contain custom static files (such as style sheets) here,
187 | # relative to this directory. They are copied after the builtin static files,
188 | # so a file named "default.css" will overwrite the builtin "default.css".
189 | html_static_path = ["_templates", "_static"]
190 |
191 | # Custom sidebar templates, must be a dictionary that maps document names
192 | # to template names.
193 | #
194 | # The default sidebars (for documents that don't match any pattern) are
195 | # defined by theme itself. Builtin themes are using these templates by
196 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
197 | # 'searchbox.html']``.
198 | #
199 | # html_sidebars = {}
200 |
201 | # -- Options for HTMLHelp output ---------------------------------------------
202 |
203 | # Output file base name for HTML help builder.
204 | htmlhelp_basename = project + "-doc"
205 |
206 | # -- Options for LaTeX output ------------------------------------------------
207 |
208 | latex_elements = {
209 | # The paper size ('letterpaper' or 'a4paper').
210 | # 'papersize': 'letterpaper',
211 | # The font size ('10pt', '11pt' or '12pt').
212 | # 'pointsize': '10pt',
213 | # Additional stuff for the LaTeX preamble.
214 | # 'preamble': '',
215 | # Latex figure (float) alignment
216 | "figure_align": "htbp",
217 | }
218 |
219 | # Grouping the document tree into LaTeX files. List of tuples
220 | # (source start file, target name, title,
221 | # author, documentclass [howto, manual, or own class]).
222 | latex_documents = [
223 | (master_doc, project + ".tex", project + " Documentation", author, "manual"),
224 | ]
225 |
226 | # -- Options for manual page output ------------------------------------------
227 |
228 | # One entry per manual page. List of tuples
229 | # (source start file, name, description, authors, manual section).
230 | man_pages = [(master_doc, project, project + " Documentation", [author], 1)]
231 |
232 | # -- Options for Texinfo output ----------------------------------------------
233 |
234 | # Grouping the document tree into Texinfo files. List of tuples
235 | # (source start file, target name, title, author,
236 | # dir menu entry, description, category)
237 | texinfo_documents = [
238 | (
239 | master_doc,
240 | project,
241 | project + " Documentation",
242 | author,
243 | project,
244 | about.__docs__,
245 | "Miscellaneous",
246 | ),
247 | ]
248 |
249 | # -- Options for Epub output -------------------------------------------------
250 |
251 | # Bibliographic Dublin Core info.
252 | epub_title = project
253 |
254 | # The unique identifier of the text. This can be a ISBN number
255 | # or the project homepage.
256 | #
257 | # epub_identifier = ''
258 |
259 | # A unique identification for the text.
260 | #
261 | # epub_uid = ''
262 |
263 | # A list of files that should not be packed into the epub file.
264 | epub_exclude_files = ["search.html"]
265 |
266 | # -- Extension configuration -------------------------------------------------
267 |
268 | # -- Options for intersphinx extension ---------------------------------------
269 |
270 | # Example configuration for intersphinx: refer to the Python standard library.
271 | intersphinx_mapping = {
272 | "python": ("https://docs.python.org/3", None),
273 | "torch": ("https://pytorch.org/docs/stable/", None),
274 | "numpy": ("https://numpy.org/doc/stable/", None),
275 | }
276 |
277 | # -- Options for to-do extension ----------------------------------------------
278 |
279 | # If true, `todo` and `todoList` produce output, else they produce nothing.
280 | todo_include_todos = True
281 |
282 |
283 | def setup(app):
284 | # this is for hiding doctest decoration,
285 | # see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/
286 | app.add_js_file("copybutton.js")
287 |
288 |
289 | # Ignoring Third-party packages
290 | # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
def _package_list_from_file(file):
    """Parse a pip requirements file and return the bare package names.

    Everything from the first occurrence of any of ``,=<>#`` onward (version
    specifiers, extras separators, comments) is cut away; blank lines are skipped.
    """
    packages = []
    with open(file) as fh:
        for line in fh:
            # trim at the first delimiter or comment character, if any
            cut_points = [line.index(ch) for ch in ",=<>#" if ch in line]
            name = (line[: min(cut_points)] if cut_points else line).rstrip()
            if name:
                packages.append(name)
    return packages
301 |
302 |
# define mapping from PyPI names to python imports
PACKAGE_MAPPING = {
    "PyYAML": "yaml",
}
# modules autodoc will mock instead of importing for real
MOCK_PACKAGES = []
if SPHINX_MOCK_REQUIREMENTS:
    # mock also base packages when we are on RTD since we don't install them there
    MOCK_PACKAGES += _package_list_from_file(
        os.path.join(_PATH_ROOT, "requirements.txt")
    )
# translate PyPI distribution names to the import names autodoc actually sees
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]

autodoc_mock_imports = MOCK_PACKAGES
316 |
317 |
318 | # Resolve function
319 | # This function is used to populate the (source) links in the API
def linkcode_resolve(domain, info):
    """Map a documented Python object to its GitHub source URL (sphinx.ext.linkcode hook).

    Args:
        domain: documentation domain; only ``"py"`` is handled.
        info: mapping with ``module`` and ``fullname`` of the documented object.

    Returns:
        The ``https://github.com/...#Lx-Ly`` URL for the object, or ``None``
        when the domain/module cannot be resolved.
    """

    def find_source():
        # try to find the file and line number, based on code from numpy:
        # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
        obj = sys.modules[info["module"]]
        for part in info["fullname"].split("."):
            obj = getattr(obj, part)
        fname = inspect.getsourcefile(obj)
        # https://github.com/rtfd/readthedocs.org/issues/5735
        if any(s in fname for s in ("readthedocs", "rtfd", "checkouts")):
            # /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/
            #  devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176
            path_top = os.path.abspath(os.path.join("..", "..", ".."))
            fname = os.path.relpath(fname, start=path_top)
        else:
            # Local build, imitate master
            fname = "master/" + os.path.relpath(fname, start=os.path.abspath(".."))
        source, lineno = inspect.getsourcelines(obj)
        return fname, lineno, lineno + len(source) - 1

    if domain != "py" or not info["module"]:
        return None
    try:
        filename = "%s#L%d-L%d" % find_source()
    except Exception:
        # fall back to the plain module path when the object cannot be introspected
        filename = info["module"].replace(".", "/") + ".py"
    branch = filename.split("/")[0]
    # do mapping from latest tags to master
    branch = {"latest": "main", "stable": "main"}.get(branch, branch)
    filename = "/".join([branch] + filename.split("/")[1:])
    # fix: the computed ``filename`` (branch + path + line span) was dropped
    # from the returned URL, so every source link pointed at a dead path
    return f"https://github.com/{github_user}/{github_repo}/blob/{filename}"
354 |
355 |
# generate stub pages for autosummary entries automatically
autosummary_generate = True

autodoc_member_order = "groupwise"
autoclass_content = "both"
# the options are fixed and will be soon in release,
# see https://github.com/sphinx-doc/sphinx/issues/5459
autodoc_default_options = {
    "members": None,
    "methods": None,
    # 'attributes': None,
    "special-members": "__call__",
    "exclude-members": "_abc_impl",
    "show-inheritance": True,
    "private-members": True,
    "noindex": True,
}

# Sphinx will add “permalinks” for each heading and description environment as paragraph signs that
# become visible when the mouse hovers over them.
# This value determines the text for the permalink; it defaults to "¶". Set it to None or the empty
# string to disable permalinks.
# https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-html_add_permalinks
# html_add_permalinks = "¶"

# True to prefix each section label with the name of the document it is in, followed by a colon.
# For example, index:Introduction for a section called Introduction that appears in document index.rst.
# Useful for avoiding ambiguity when the same section heading appears in different documents.
# http://www.sphinx-doc.org/en/master/usage/extensions/autosectionlabel.html
autosectionlabel_prefix_document = True

# only run doctests marked with a ".. doctest::" directive
doctest_test_doctest_blocks = ""
doctest_global_setup = """
"""
coverage_skip_undoc_in_source = True

linkcheck_ignore = [
    # ignore the following URLs, which linkcheck cannot reliably verify
    "https://github.com/gridai/lit-logger",
]
396 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Lightning-AI-Sandbox documentation master file, created by
2 | sphinx-quickstart on Wed Mar 25 21:34:07 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Lightning-Models
7 | ================
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 | :name: start
12 | :caption: Start here
13 |
14 | readme
15 |
16 |
17 | Indices and tables
18 | ==================
19 |
20 | * :ref:`genindex`
21 | * :ref:`modindex`
22 | * :ref:`search`
23 |
--------------------------------------------------------------------------------
/examples/demo-tensorflow-keras.py:
--------------------------------------------------------------------------------
"""Demo: save and load a TensorFlow/Keras model via the litmodels registry."""

from tensorflow import keras

from litmodels import load_model, save_model

if __name__ == "__main__":
    # Define the model: a small two-layer dense network
    model = keras.Sequential([
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ])

    # Compile the model
    model.compile(optimizer="adam", loss="categorical_crossentropy")

    # Save the model under the registry name <organization>/<teamspace>/<model-name>
    save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)

    # Load the model back; artifacts are fetched into ``download_dir``
    model_ = load_model("lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model")
20 |
--------------------------------------------------------------------------------
/examples/demo-upload-download.py:
--------------------------------------------------------------------------------
"""Demo: upload a raw checkpoint file to the registry and download it again."""

import torch
from lightning.pytorch.demos.boring_classes import BoringModel

import litmodels

if __name__ == "__main__":
    # Define your model
    model = BoringModel()

    # Save the model's state dictionary
    torch.save(model.state_dict(), "./boring-checkpoint.pt")

    # Upload the model checkpoint
    litmodels.upload_model(
        "./boring-checkpoint.pt",
        "lightning-ai/jirka/lit-boring-model",
    )

    # Download the model checkpoint
    model_path = litmodels.download_model("lightning-ai/jirka/lit-boring-model", download_dir="./my-models")
    print(f"Model downloaded to {model_path}")

    # Load the model checkpoint
    # NOTE(review): this loads the original local file, not the copy just
    # downloaded into ``./my-models`` -- presumably fine for the demo, but
    # confirm the intent was not to load ``model_path`` instead.
    loaded_model = BoringModel()
    loaded_model.load_state_dict(torch.load("./boring-checkpoint.pt"))
    print(loaded_model)
27 |
--------------------------------------------------------------------------------
/examples/resume-lightning-training.py:
--------------------------------------------------------------------------------
1 | """
2 | This example demonstrates how to resume training of a model using the `download_model` function.
3 | """
4 |
5 | import os
6 |
7 | from lightning import Trainer
8 | from lightning.pytorch.demos.boring_classes import BoringModel
9 |
10 | from litmodels import download_model
11 |
12 | # Define the model name - this should be unique to your model
13 | # The format is //:
14 | MY_MODEL_NAME = "lightning-ai/jirka/lit-boring-callback:latest"
15 |
16 |
17 | if __name__ == "__main__":
18 | model_files = download_model(name=MY_MODEL_NAME, download_dir="my_models")
19 | model_path = os.path.join("my_models", model_files[0])
20 | print(f"model: {model_path}")
21 |
22 | trainer = Trainer(max_epochs=4)
23 | trainer.fit(
24 | BoringModel(),
25 | ckpt_path=model_path,
26 | )
27 |
--------------------------------------------------------------------------------
/examples/train-model-and-simple-save.py:
--------------------------------------------------------------------------------
1 | """
2 | This example demonstrates how to train a model and upload it to the cloud using the `upload_model` function.
3 | """
4 |
5 | from lightning import Trainer
6 | from lightning.pytorch.demos.boring_classes import BoringModel
7 |
8 | from litmodels import upload_model
9 |
10 | # Define the model name - this should be unique to your model
11 | # The format is //
12 | MY_MODEL_NAME = "lightning-ai/jirka/lit-boring-simple"
13 |
14 |
15 | if __name__ == "__main__":
16 | trainer = Trainer(max_epochs=2)
17 | trainer.fit(BoringModel())
18 | checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
19 | print(f"best: {checkpoint_path}")
20 | upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
21 |
--------------------------------------------------------------------------------
/examples/train-model-with-lightning-callback.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a model with a Lightning callback that uploads the best model to the cloud after each epoch.
3 | """
4 |
5 | from lightning import Trainer
6 | from lightning.pytorch.demos.boring_classes import BoringModel
7 |
8 | from litmodels.integrations import LightningModelCheckpoint
9 |
10 | # Define the model name - this should be unique to your model
11 | # The format is //
12 | MY_MODEL_NAME = "lightning-ai/jirka/lit-boring-callback"
13 |
14 |
15 | if __name__ == "__main__":
16 | trainer = Trainer(
17 | max_epochs=2,
18 | callbacks=LightningModelCheckpoint(model_registry=MY_MODEL_NAME),
19 | )
20 | trainer.fit(BoringModel())
21 |
--------------------------------------------------------------------------------
/examples/train-model-with-lightning-logger.py:
--------------------------------------------------------------------------------
1 | """
2 | # Enhanced Logging with LightningLogger
3 |
4 | Integrate with [LitLogger](https://github.com/gridai/lit-logger) to automatically log your model checkpoints
5 | and training metrics to cloud storage.
6 | Though the example utilizes PyTorch Lightning, this integration concept works across various model training frameworks.
7 |
8 | """
9 |
10 | from lightning import Trainer
11 | from lightning.pytorch.demos.boring_classes import BoringModel
12 | from litlogger import LightningLogger
13 |
14 |
15 | class DemoModel(BoringModel):
16 | def training_step(self, batch, batch_idx):
17 | output = super().training_step(batch, batch_idx)
18 | self.log("train_loss", output["loss"])
19 | return output
20 |
21 |
22 | if __name__ == "__main__":
23 | # configure the logger
24 | lit_logger = LightningLogger(log_model=True)
25 |
26 | # pass logger to the Trainer
27 | trainer = Trainer(max_epochs=5, logger=lit_logger)
28 |
29 | # train the model
30 | trainer.fit(model=DemoModel())
31 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_file = "LICENSE"
3 | description-file = "README.md"
4 |
5 | [build-system]
6 | requires = [
7 | "setuptools",
8 | "wheel",
9 | ]
10 |
11 |
12 | [tool.check-manifest]
13 | ignore = [
14 | "*.yml",
15 | ".github",
16 | ".github/*"
17 | ]
18 |
19 |
20 | [tool.pytest.ini_options]
21 | norecursedirs = [
22 | ".git",
23 | ".github",
24 | "dist",
25 | "build",
26 | "docs",
27 | ]
28 | addopts = [
29 | "--strict-markers",
30 | "--doctest-modules",
31 | "--color=yes",
32 | "--disable-pytest-warnings",
33 | ]
34 | markers = [
35 | "cloud:Run the cloud tests for example",
36 | ]
37 | filterwarnings = [
38 | "error::FutureWarning",
39 | ]
40 | xfail_strict = true
41 | junit_duration_report = "call"
42 |
43 | [tool.coverage.report]
44 | exclude_lines = [
45 | "pragma: no cover",
46 | "pass",
47 | ]
48 |
49 | [tool.codespell]
50 | #skip = '*.py'
51 | quiet-level = 3
52 | # comma separated list of words; waiting for:
53 | # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603
54 | # also adding links until they ignored by its: nature
55 | # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960
56 | #ignore-words-list = ""
57 |
58 |
59 | [tool.docformatter]
60 | recursive = true
61 | wrap-summaries = 120
62 | wrap-descriptions = 120
63 | blank = true
64 |
65 |
66 | [tool.mypy]
67 | files = [
68 | "src",
69 | ]
70 | install_types = true
71 | non_interactive = true
72 | disallow_untyped_defs = true
73 | ignore_missing_imports = true
74 | show_error_codes = true
75 | warn_redundant_casts = true
76 | warn_unused_configs = true
77 | warn_unused_ignores = true
78 | allow_redefinition = true
79 | # disable this rule as the Trainer attributes are defined in the connectors, not in its __init__
80 | disable_error_code = "attr-defined"
81 | # style choices
82 | warn_no_return = false
83 |
84 |
85 | [tool.ruff]
86 | line-length = 120
87 | target-version = 'py38'
88 | exclude = [
89 | "build",
90 | "dist",
91 | "docs"
92 | ]
93 |
94 | [tool.ruff.format]
95 | preview = true
96 |
97 |
98 | [tool.ruff.lint]
99 | select = [
100 | "E", "W", # see: https://pypi.org/project/pycodestyle
101 | "F", # see: https://pypi.org/project/pyflakes
102 | "D", # see: https://pypi.org/project/pydocstyle
103 | "N", # see: https://pypi.org/project/pep8-naming
104 | "RUF018", # see: https://docs.astral.sh/ruff/rules/assignment-in-assert
105 | "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up
106 | "I", # implementation for isort
107 | ]
108 | extend-select = [
109 | "C4", # see: https://pypi.org/project/flake8-comprehensions
110 | "PT", # see: https://pypi.org/project/flake8-pytest-style
111 | "RET", # see: https://pypi.org/project/flake8-return
112 | "SIM", # see: https://pypi.org/project/flake8-simplify
113 | ]
114 | ignore = [
115 | "E731", # Do not assign a lambda expression, use a def
116 | "D100", # Missing docstring in public module
117 | ]
118 |
119 | [tool.ruff.lint.per-file-ignores]
120 | "setup.py" = ["D100", "SIM115"]
121 | "__about__.py" = ["D100"]
122 | "__init__.py" = ["D100", "E402"]
123 | "examples/**" = ["D"] # todo
124 | "tests/**" = ["D"]
125 |
126 | [tool.ruff.lint.pydocstyle]
127 | # Use Google-style docstrings.
128 | convention = "google"
129 |
130 | #[tool.ruff.pycodestyle]
131 | #ignore-overlong-task-comments = true
132 |
133 | [tool.ruff.lint.mccabe]
134 | # Unlike Flake8, default to a complexity level of 10.
135 | max-complexity = 10
136 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | lightning-sdk >=0.2.11
2 | lightning-utilities
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import glob
3 | import os
4 | from importlib.util import module_from_spec, spec_from_file_location
5 | from pathlib import Path
6 |
7 | from pkg_resources import parse_requirements
8 | from setuptools import find_packages, setup
9 |
10 | _PATH_ROOT = os.path.dirname(__file__)
11 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
12 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements")
13 |
14 |
def _load_py_module(fname, pkg="litmodels"):
    """Import a module from ``src/<pkg>/<fname>`` without installing the package."""
    module_name = os.path.join(pkg, fname)
    module_path = os.path.join(_PATH_SOURCE, pkg, fname)
    spec = spec_from_file_location(module_name, module_path)
    module = module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
20 |
21 |
def _load_requirements(path_dir: str = _PATH_ROOT, file_name: str = "requirements.txt") -> list:
    """Read a requirements file and return its entries as requirement strings.

    Args:
        path_dir: directory containing the requirements file.
        file_name: name of the requirements file to parse.
    """
    # fix: close the file deterministically instead of leaking the handle until GC
    with open(os.path.join(path_dir, file_name)) as fh:
        # readlines() is consumed eagerly, so closing before str-mapping is safe
        reqs = parse_requirements(fh.readlines())
    return list(map(str, reqs))
25 |
26 |
# package metadata loaded straight from the source tree (no install required)
about = _load_py_module("__about__.py")
with open(os.path.join(_PATH_ROOT, "README.md"), encoding="utf-8") as fopen:
    readme = fopen.read()
30 |
31 |
def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple = ("devel.txt",)) -> dict:
    """Build the ``extras_require`` mapping from per-topic requirement files.

    https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras
    Define package extras. These are only installed if you specify them.
    From remote, use like `pip install pytorch-lightning[dev, docs]`
    From local copy of repo, use like `pip install ".[dev, docs]"`
    """
    extras = {}
    for req_path in (Path(p) for p in glob.glob(os.path.join(requirements_dir, "*.txt"))):
        # ignore some development specific requirements
        if req_path.name in skip_files or req_path.name.startswith("_"):
            continue
        requirements = _load_requirements(file_name=req_path.name, path_dir=str(req_path.parent))
        # de-duplicate and keep a stable, sorted order
        extras[req_path.stem] = sorted(set(requirements))
    # todo: eventually add some custom aggregations such as `develop`
    print("The extras are: ", extras)
    return extras
48 |
49 |
# https://packaging.python.org/discussions/install-requires-vs-requirements/
# keep the meta-data here for simplicity in reading this file... it's not obvious
# what happens and to non-engineers they won't know to look in init ...
# the goal of the project is simplicity for researchers, don't want to add too much
# engineer specific practices
setup(
    name="litmodels",
    version=about.__version__,
    description=about.__docs__,
    author=about.__author__,
    author_email=about.__author_email__,
    url=about.__homepage__,
    download_url="https://github.com/Lightning-AI/litModels",
    license=about.__license__,
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    long_description=readme,
    long_description_content_type="text/markdown",
    include_package_data=True,
    zip_safe=False,
    keywords=["deep learning", "pytorch", "AI"],
    python_requires=">=3.8",
    setup_requires=["wheel"],
    install_requires=_load_requirements(),
    extras_require=_prepare_extras(),
    project_urls={
        "Bug Tracker": "https://github.com/Lightning-AI/litModels/issues",
        "Documentation": "https://lightning-ai.github.io/litModels/",
        "Source Code": "https://github.com/Lightning-AI/litModels",
    },
    classifiers=[
        "Environment :: Console",
        "Natural Language :: English",
        # How mature is this project? Common values are
        #   3 - Alpha, 4 - Beta, 5 - Production/Stable
        "Development Status :: 3 - Alpha",
        # Indicate who your project is intended for
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Information Analysis",
        # Pick your license as you wish
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        # Specify the Python versions you support here. In particular, ensure
        # that you indicate whether you support Python 2, Python 3 or both.
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
)
104 |
--------------------------------------------------------------------------------
/src/litmodels/__about__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.8"
2 | __author__ = "Lightning-AI et al."
3 | __author_email__ = "community@lightning.ai"
4 | __license__ = "Apache-2.0"
5 | __copyright__ = f"Copyright (c) 2024-2025, {__author__}."
6 | __homepage__ = "https://github.com/Lightning-AI/litModels"
7 | __docs__ = "Lightning AI Model hub."
8 |
9 | __all__ = [
10 | "__author__",
11 | "__author_email__",
12 | "__copyright__",
13 | "__docs__",
14 | "__homepage__",
15 | "__license__",
16 | "__version__",
17 | ]
18 |
--------------------------------------------------------------------------------
/src/litmodels/__init__.py:
--------------------------------------------------------------------------------
1 | """Root package info."""
2 |
3 | import os
4 |
5 | from litmodels.__about__ import * # noqa: F401, F403
6 |
7 | _PACKAGE_ROOT = os.path.dirname(__file__)
8 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
9 |
10 | from litmodels.io import download_model, load_model, save_model, upload_model, upload_model_files # noqa: F401
11 |
12 | __all__ = ["download_model", "upload_model", "load_model", "save_model"]
13 |
--------------------------------------------------------------------------------
/src/litmodels/demos/__init__.py:
--------------------------------------------------------------------------------
1 | """Define a demos model for examples and testing purposes."""
2 |
3 | from lightning_utilities import module_available
4 |
5 | __all__ = []
6 |
7 | if module_available("lightning"):
8 | from lightning.pytorch.demos.boring_classes import BoringModel, DemoModel
9 |
10 | __all__ += ["BoringModel", "DemoModel"]
11 | elif module_available("pytorch_lightning"):
12 | from pytorch_lightning.demos.boring_classes import BoringModel, DemoModel
13 |
14 | __all__ += ["BoringModel", "DemoModel"]
15 |
--------------------------------------------------------------------------------
/src/litmodels/integrations/__init__.py:
--------------------------------------------------------------------------------
1 | """Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others."""
2 |
3 | from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
4 |
5 | __all__ = []
6 |
7 | if _LIGHTNING_AVAILABLE:
8 | from litmodels.integrations.checkpoints import LightningModelCheckpoint
9 |
10 | __all__ += ["LightningModelCheckpoint"]
11 |
12 | if _PYTORCHLIGHTNING_AVAILABLE:
13 | from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint
14 |
15 | __all__ += ["PytorchLightningModelCheckpoint"]
16 |
--------------------------------------------------------------------------------
/src/litmodels/integrations/checkpoints.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import os.path
3 | import queue
4 | import threading
5 | from abc import ABC
6 | from datetime import datetime
7 | from functools import lru_cache
8 | from pathlib import Path
9 | from typing import TYPE_CHECKING, Any, Optional, Union
10 |
11 | from lightning_sdk.lightning_cloud.login import Auth
12 | from lightning_sdk.utils.resolve import _resolve_teamspace
13 | from lightning_utilities import StrEnum
14 | from lightning_utilities.core.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn
15 |
16 | from litmodels import upload_model
17 | from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
18 | from litmodels.io.cloud import _list_available_teamspaces, delete_model_version
19 |
20 | if _LIGHTNING_AVAILABLE:
21 | from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint
22 |
23 |
24 | if _PYTORCHLIGHTNING_AVAILABLE:
25 | from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint
26 |
27 |
28 | if TYPE_CHECKING:
29 | if _LIGHTNING_AVAILABLE:
30 | import lightning.pytorch as pl
31 | if _PYTORCHLIGHTNING_AVAILABLE:
32 | import pytorch_lightning as pl
33 |
34 |
35 | # Create a singleton upload manager
@lru_cache(maxsize=None)
def get_model_manager() -> "ModelManager":
    """Get or create the singleton upload manager.

    ``lru_cache`` on a zero-argument function makes every call return the
    same ``ModelManager`` instance.
    """
    return ModelManager()
40 |
41 |
42 | # enumerate the possible actions
class Action(StrEnum):
    """Enumeration of possible actions for the ModelManager."""

    UPLOAD = "upload"  # push a checkpoint to the model registry
    REMOVE = "remove"  # delete a checkpoint locally and/or from the cloud
48 |
49 |
class RemoveType(StrEnum):
    """Enumeration of possible remove types for the ModelManager.

    NOTE(review): not referenced in this part of the module; presumably
    distinguishes the local-file vs cloud-registry removal paths -- confirm
    against the rest of the file.
    """

    LOCAL = "local"
    CLOUD = "cloud"
55 |
56 |
class ModelManager:
    """Manages uploads and removals with a single queue but separate counters.

    A single daemon worker thread consumes tasks from ``task_queue``;
    ``upload_count``/``remove_count`` track how many tasks of each kind are
    still pending.
    """

    # FIFO of (Action, detail) tuples consumed by the worker thread
    task_queue: queue.Queue

    def __init__(self) -> None:
        """Initialize the ModelManager with a task queue and counters."""
        self.task_queue = queue.Queue()
        self.upload_count = 0
        self.remove_count = 0
        self._worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker.start()

    def __getstate__(self) -> dict:
        """Get the state of the ModelManager for pickling.

        The queue and the worker thread are not picklable, so they are dropped
        here and recreated in ``__setstate__``.
        """
        state = self.__dict__.copy()
        del state["task_queue"]
        del state["_worker"]
        return state

    def __setstate__(self, state: dict) -> None:
        """Set the state of the ModelManager after unpickling."""
        self.__dict__.update(state)
        # fix: dropped the redundant local re-imports of ``queue`` and
        # ``threading`` -- both are already imported at module level
        self.task_queue = queue.Queue()
        self._worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker.start()

    def _worker_loop(self) -> None:
        """Consume queued tasks until the ``None`` sentinel is received."""
        while True:
            task = self.task_queue.get()
            if task is None:
                # sentinel pushed by ``shutdown`` -- acknowledge and stop
                self.task_queue.task_done()
                break
            action, detail = task
            if action == Action.UPLOAD:
                registry_name, filepath, metadata = detail
                try:
                    upload_model(name=registry_name, model=filepath, metadata=metadata)
                    rank_zero_debug(f"Finished uploading: {filepath}")
                except Exception as ex:
                    rank_zero_warn(f"Upload failed {filepath}: {ex}")
                finally:
                    self.upload_count -= 1
            elif action == Action.REMOVE:
                filepath, trainer, registry_name = detail
                try:
                    if registry_name:
                        rank_zero_debug(f"Removing from cloud: {filepath}")
                        # Remove from the cloud; the version name mirrors the checkpoint file stem
                        version = os.path.splitext(os.path.basename(filepath))[0]
                        delete_model_version(name=registry_name, version=version)
                    if trainer:
                        # fix: perform the removal before logging it as done
                        trainer.strategy.remove_checkpoint(filepath)
                        rank_zero_debug(f"Removed local file: {filepath}")
                except Exception as ex:
                    rank_zero_warn(f"Removal failed {filepath}: {ex}")
                finally:
                    self.remove_count -= 1
            else:
                rank_zero_warn(f"Unknown task: {task}")
            self.task_queue.task_done()

    def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
        """Queue an upload task.

        Args:
            registry_name: target model registry name.
            filepath: checkpoint file to upload.
            metadata: optional metadata stored with the uploaded model.
        """
        self.upload_count += 1
        self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata)))
        rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")

    def queue_remove(
        self, filepath: Union[str, Path], trainer: Optional["pl.Trainer"] = None, registry_name: Optional[str] = None
    ) -> None:
        """Queue a removal task.

        Args:
            filepath: checkpoint path to remove.
            trainer: when given, its strategy also removes the local checkpoint.
            registry_name: when given, the matching cloud model version is deleted.
        """
        self.remove_count += 1
        self.task_queue.put((Action.REMOVE, (filepath, trainer, registry_name)))
        rank_zero_debug(f"Queued removal: {filepath} (pending removals: {self.remove_count})")

    def shutdown(self) -> None:
        """Shut down the manager and wait for all tasks to complete."""
        self.task_queue.put(None)
        self.task_queue.join()
        rank_zero_debug("Manager shut down.")
141 |
142 |
143 | # Base class to be inherited
class LitModelCheckpointMixin(ABC):
    """Mixin class for LitModel checkpoint functionality.

    Concrete checkpoint callbacks inherit this mixin to queue checkpoint uploads
    to (and removals from) the Lightning model registry via the process-wide
    :class:`ModelManager`.
    """

    # Timestamp used to build a default model name when none was provided.
    _datetime_stamp: str
    model_registry: Optional[str] = None
    _model_manager: ModelManager

    def __init__(
        self, model_registry: Optional[str], keep_all_uploaded: bool = False, clear_all_local: bool = False
    ) -> None:
        """Initialize with model name.

        Args:
            model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
            keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
            clear_all_local: Whether to clear local models after uploading to the cloud.

        Raises:
            ConnectionError: If authentication with Lightning Cloud fails.
        """
        if not model_registry:
            rank_zero_warn(
                "The model is not defined so we will continue with LightningModule names and timestamp of now"
            )
        self._datetime_stamp = datetime.now().strftime("%Y%m%d-%H%M")
        # remove any / from beginning and end of the name
        self.model_registry = model_registry.strip("/") if model_registry else None
        self._keep_all_uploaded = keep_all_uploaded
        self._clear_all_local = clear_all_local

        try:  # authenticate before anything else starts
            Auth().authenticate()
        except Exception as ex:
            # Chain the original exception so the root cause is not lost.
            raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") from ex

        # Reuse the process-wide singleton instead of instantiating a private
        # `ModelManager()`: a private instance would spawn an extra worker thread
        # that no method of this class ever feeds -- all queueing below goes
        # through `get_model_manager()`.
        self._model_manager = get_model_manager()

    @rank_zero_only
    def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
        """Queue the checkpoint at ``filepath`` for upload (and optional local cleanup).

        Args:
            trainer: The running trainer (used for local file removal when requested).
            filepath: Checkpoint file or folder to upload.
            metadata: Optional metadata to attach; the caller's dict is not mutated.

        Raises:
            RuntimeError: If no registry name has been resolved yet.
        """
        if not self.model_registry:
            raise RuntimeError(
                "Model name is not specified neither updated by `setup` method via Trainer."
                " Please set the model name before uploading or ensure that `setup` method is called."
            )
        model_registry = self.model_registry
        if os.path.isfile(filepath):
            # parse the file name as version
            version, _ = os.path.splitext(os.path.basename(filepath))
            model_registry += f":{version}"
        # Copy so the caller's dict is not mutated as a side effect.
        metadata = dict(metadata) if metadata else {}
        # Add the integration name to the metadata: the class sitting just before
        # this mixin in the MRO is the concrete checkpoint subclass.
        mro = inspect.getmro(type(self))
        abc_index = mro.index(LitModelCheckpointMixin)
        ckpt_class = mro[abc_index - 1]
        metadata.update({"litModels.integration": ckpt_class.__name__})
        # Add to queue instead of uploading directly so training is not blocked.
        get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata)
        if self._clear_all_local:
            get_model_manager().queue_remove(filepath=filepath, trainer=trainer)

    @rank_zero_only
    def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
        """Remove the local version of the model if requested."""
        get_model_manager().queue_remove(
            filepath=filepath,
            # skip the local removal we put it in the queue right after the upload
            trainer=None if self._clear_all_local else trainer,
            # skip the cloud removal if we keep all uploaded models
            registry_name=None if self._keep_all_uploaded else self.model_registry,
        )

    def default_model_name(self, pl_model: "pl.LightningModule") -> str:
        """Generate a default model name based on the class name and timestamp."""
        return pl_model.__class__.__name__ + f"_{self._datetime_stamp}"

    def _update_model_name(self, pl_model: "pl.LightningModule") -> None:
        """Update the model name if not already set.

        Resolves a partial (or missing) registry name to the full
        'organization/teamspace/modelname' form using the LightningModule's class
        name, the environment's teamspace, or the user's single teamspace.

        Raises:
            ValueError: If the registry name contains more than two '/' characters.
            RuntimeError: If the teamspace cannot be determined unambiguously.
        """
        count_slashes_in_name = self.model_registry.count("/") if self.model_registry else 0
        default_model_name = self.default_model_name(pl_model)
        if count_slashes_in_name > 2:
            raise ValueError(
                f"Invalid model name: '{self.model_registry}'. It should not contain more than two '/' character."
            )
        if count_slashes_in_name == 2:
            # user has defined the model name in the format 'organization/teamspace/modelname'
            return
        if count_slashes_in_name == 1:
            # user had defined only the teamspace name
            self.model_registry = f"{self.model_registry}/{default_model_name}"
        elif count_slashes_in_name == 0:
            if not self.model_registry:
                self.model_registry = default_model_name
            teamspace = _resolve_teamspace(None, None, None)
            if teamspace:
                # case you use default model name and teamspace determined from env. variables aka running in studio
                self.model_registry = f"{teamspace.owner.name}/{teamspace.name}/{self.model_registry}"
            else:  # try to load default users teamspace
                ts_names = list(_list_available_teamspaces().keys())
                if len(ts_names) == 1:
                    self.model_registry = f"{ts_names[0]}/{self.model_registry}"
                else:
                    options = "\n\t".join(ts_names)
                    raise RuntimeError(
                        f"Teamspace is not defined and there are multiple teamspaces available:\n{options}"
                    )
        else:
            raise RuntimeError(f"Invalid model name: '{self.model_registry}'")
249 |
250 |
251 | # Create specific implementations
if _LIGHTNING_AVAILABLE:

    class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint):
        """Lightning ModelCheckpoint with LitModel support.

        Args:
            model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
            keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
            clear_all_local: Whether to clear local models after uploading to the cloud.
            *args: Additional arguments to pass to the parent class.
            **kwargs: Additional keyword arguments to pass to the parent class.
        """

        def __init__(
            self,
            *args: Any,
            model_name: Optional[str] = None,
            model_registry: Optional[str] = None,
            keep_all_uploaded: bool = False,
            clear_all_local: bool = False,
            **kwargs: Any,
        ) -> None:
            """Initialize the checkpoint with model name and other parameters."""
            # Both bases are initialized explicitly because they take different argument sets.
            _LightningModelCheckpoint.__init__(self, *args, **kwargs)
            if model_name is not None:
                rank_zero_warn(
                    "The 'model_name' argument is deprecated and will be removed in a future version."
                    " Please use 'model_registry' instead."
                )
            LitModelCheckpointMixin.__init__(
                self,
                model_registry=model_registry or model_name,
                keep_all_uploaded=keep_all_uploaded,
                clear_all_local=clear_all_local,
            )

        def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
            """Setup the checkpoint callback and resolve the full registry name from the module."""
            _LightningModelCheckpoint.setup(self, trainer, pl_module, stage)
            self._update_model_name(pl_module)

        def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
            """Extend the save checkpoint method to upload the model."""
            _LightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
            if trainer.is_global_zero:  # Only upload from the main process
                self._upload_model(trainer=trainer, filepath=filepath)

        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
            """Extend the on_fit_end method to ensure all uploads are completed."""
            _LightningModelCheckpoint.on_fit_end(self, trainer, pl_module)
            # Wait for all uploads to finish.
            # NOTE(review): this stops the singleton manager's worker thread; tasks
            # queued after fit ends would never be processed -- confirm intended.
            get_model_manager().shutdown()

        def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
            """Extend the remove checkpoint method to remove the model from the registry."""
            if trainer.is_global_zero:  # Only remove from the main process
                self._remove_model(trainer=trainer, filepath=filepath)
309 |
310 |
if _PYTORCHLIGHTNING_AVAILABLE:

    class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
        """PyTorch Lightning ModelCheckpoint with LitModel support.

        Args:
            model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
            keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
            clear_all_local: Whether to clear local models after uploading to the cloud.
            *args: Additional arguments to pass to the parent class.
            **kwargs: Additional keyword arguments to pass to the parent class.
        """

        def __init__(
            self,
            *args: Any,
            model_name: Optional[str] = None,
            model_registry: Optional[str] = None,
            keep_all_uploaded: bool = False,
            clear_all_local: bool = False,
            **kwargs: Any,
        ) -> None:
            """Initialize the checkpoint with model name and other parameters."""
            # Both bases are initialized explicitly because they take different argument sets.
            _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
            if model_name is not None:
                rank_zero_warn(
                    "The 'model_name' argument is deprecated and will be removed in a future version."
                    " Please use 'model_registry' instead."
                )
            LitModelCheckpointMixin.__init__(
                self,
                model_registry=model_registry or model_name,
                keep_all_uploaded=keep_all_uploaded,
                clear_all_local=clear_all_local,
            )

        def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
            """Setup the checkpoint callback and resolve the full registry name from the module."""
            _PytorchLightningModelCheckpoint.setup(self, trainer, pl_module, stage)
            self._update_model_name(pl_module)

        def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
            """Extend the save checkpoint method to upload the model."""
            _PytorchLightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
            if trainer.is_global_zero:  # Only upload from the main process
                self._upload_model(trainer=trainer, filepath=filepath)

        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
            """Extend the on_fit_end method to ensure all uploads are completed."""
            _PytorchLightningModelCheckpoint.on_fit_end(self, trainer, pl_module)
            # Wait for all uploads to finish.
            # NOTE(review): this stops the singleton manager's worker thread; tasks
            # queued after fit ends would never be processed -- confirm intended.
            get_model_manager().shutdown()

        def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
            """Extend the remove checkpoint method to remove the model from the registry."""
            if trainer.is_global_zero:  # Only remove from the main process
                self._remove_model(trainer=trainer, filepath=filepath)
368 |
--------------------------------------------------------------------------------
/src/litmodels/integrations/duplicate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tempfile
4 | from pathlib import Path
5 | from typing import Optional
6 |
7 | from lightning_utilities import module_available
8 |
9 | from litmodels.io import upload_model_files
10 |
11 | if module_available("huggingface_hub"):
12 | from huggingface_hub import snapshot_download
13 | else:
14 | snapshot_download = None
15 |
16 |
def duplicate_hf_model(
    hf_model: str,
    lit_model: Optional[str] = None,
    local_workdir: Optional[str] = None,
    verbose: int = 1,
    metadata: Optional[dict] = None,
) -> str:
    """Downloads the model from Hugging Face and uploads it to Lightning Cloud.

    Args:
        hf_model: The name of the Hugging Face model to duplicate.
        lit_model: The name of the Lightning Cloud model to create.
        local_workdir:
            The local working directory to use for the duplication process. If not set a temp folder will be created.
        verbose: Show a progress bar for the upload.
        metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

    Returns:
        The name of the duplicated model in Lightning Cloud.

    Raises:
        ModuleNotFoundError: If `huggingface_hub` is not installed.
    """
    if not snapshot_download:
        raise ModuleNotFoundError(
            "Hugging Face Hub is not installed. Please install it with `pip install huggingface_hub`."
        )

    if not local_workdir:
        local_workdir = tempfile.mkdtemp()
    local_workdir = Path(local_workdir)
    model_name = hf_model.replace("/", "_")

    # Enable the accelerated transfer backend for large downloads (if installed).
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    # Download the model from Hugging Face
    snapshot_download(
        repo_id=hf_model,
        revision="main",  # Branch/tag/commit
        repo_type="model",  # Options: "dataset", "model", "space"
        local_dir=local_workdir / model_name,  # Specify to save in custom location, default is cache
        local_dir_use_symlinks=True,  # Use symlinks to save disk space
        ignore_patterns=[".cache*"],  # Exclude certain files if needed
        max_workers=os.cpu_count(),  # Number of parallel downloads
    )
    # Prune HF cache leftovers in the downloaded model. Materialize the matches
    # first so we do not delete directories while `rglob` is still walking them,
    # and handle plain files as well (`shutil.rmtree` fails on non-directories).
    for path in list(local_workdir.rglob(".cache*")):
        if path.is_dir():
            shutil.rmtree(path, ignore_errors=True)
        elif path.exists():
            path.unlink()

    # Upload the model to Lightning Cloud
    if not lit_model:
        lit_model = model_name
    if not metadata:
        metadata = {}
    metadata.update({"litModels.integration": "duplicate_hf_model", "hf_model": hf_model})
    model = upload_model_files(name=lit_model, path=local_workdir / model_name, verbose=verbose, metadata=metadata)
    return model.name
70 |
--------------------------------------------------------------------------------
/src/litmodels/integrations/imports.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | from lightning_utilities import compare_version, module_available
4 |
# Flags describing which Lightning distributions are importable in this environment.
_LIGHTNING_AVAILABLE = module_available("lightning")
# NOTE(review): presumably `compare_version` returns False when the package is absent -- confirm.
_LIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("lightning", operator.ge, "2.5.1")
_PYTORCHLIGHTNING_AVAILABLE = module_available("pytorch_lightning")
_PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("pytorch_lightning", operator.ge, "2.5.1")
9 |
--------------------------------------------------------------------------------
/src/litmodels/integrations/mixins.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import json
3 | import tempfile
4 | import warnings
5 | from abc import ABC
6 | from pathlib import Path
7 | from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8 |
9 | from lightning_utilities.core.rank_zero import rank_zero_warn
10 |
11 | from litmodels.io.cloud import download_model_files, upload_model_files
12 | from litmodels.io.utils import dump_pickle, load_pickle
13 |
14 | if TYPE_CHECKING:
15 | import torch
16 |
17 |
class ModelRegistryMixin(ABC):
    """Mixin that lets a model class push itself to / pull itself from the registry."""

    def upload_model(
        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> None:
        """Push the model to the registry.

        Args:
            name: The name of the model. If not use the class name.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
        """

    @classmethod
    def download_model(
        cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> object:
        """Pull the model from the registry.

        Args:
            name: The name of the model.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
        """

    def _setup(
        self, name: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> Tuple[str, str, Union[str, Path]]:
        """Resolve the registry name, the bare model name, and a working folder."""
        if name and ":" in name:
            # a version suffix must not be embedded in the plain name
            raise ValueError(f"Invalid model name: '{name}'. It should not contain ':' associated with version.")
        if name is None:
            name = self.__class__.__name__
            model_name = name
        else:
            model_name = name.rsplit("/", 1)[-1]
        if temp_folder is None:
            temp_folder = tempfile.mkdtemp()
        return name, model_name, temp_folder

    def _upload_model_files(
        self, name: str, path: Union[str, Path, List[Union[str, Path]]], metadata: Optional[dict] = None
    ) -> None:
        """Upload the model files to the registry, tagging the concrete mixin subclass in metadata."""
        if not metadata:
            metadata = {}
        # The class sitting just before this mixin in the MRO is the concrete subclass.
        lineage = inspect.getmro(type(self))
        child = lineage[lineage.index(ModelRegistryMixin) - 1]
        metadata["litModels.integration"] = child.__name__
        upload_model_files(name=name, path=path, metadata=metadata)
70 |
71 |
class PickleRegistryMixin(ModelRegistryMixin):
    """Registry mixin that serializes the whole object with pickle."""

    def upload_model(
        self,
        name: Optional[str] = None,
        version: Optional[str] = None,
        temp_folder: Union[str, Path, None] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Push the model to the registry.

        Args:
            name: The name of the model. If not use the class name.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
            metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
        """
        name, model_name, temp_folder = self._setup(name, temp_folder)
        pickle_path = Path(temp_folder) / f"{model_name}.pkl"
        dump_pickle(model=self, path=pickle_path)
        registry_name = f"{name}:{version}" if version else name
        self._upload_model_files(name=registry_name, path=pickle_path, metadata=metadata)

    @classmethod
    def download_model(
        cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> object:
        """Pull the model from the registry.

        Args:
            name: The name of the model.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
        """
        temp_folder = temp_folder if temp_folder is not None else tempfile.mkdtemp()
        model_registry = f"{name}:{version}" if version else name
        files = download_model_files(name=model_registry, download_dir=temp_folder)
        pkl_files = [fname for fname in files if fname.endswith(".pkl")]
        # exactly one pickle file is expected among the downloaded files
        if not pkl_files:
            raise RuntimeError(f"No pickle file found for model: {model_registry} with {files}")
        if len(pkl_files) > 1:
            raise RuntimeError(f"Multiple pickle files found for model: {model_registry} with {pkl_files}")
        obj = load_pickle(path=Path(temp_folder) / pkl_files[0])
        if not isinstance(obj, cls):
            raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}")
        return obj
122 |
123 |
class PyTorchRegistryMixin(ModelRegistryMixin):
    """Mixin for PyTorch model registry integration."""

    def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
        """Create a new instance and record the resolved ``__init__`` arguments."""
        instance = super().__new__(cls)

        # Get __init__ signature excluding 'self'
        init_sig = inspect.signature(cls.__init__)
        params = list(init_sig.parameters.values())[1:]  # Skip self

        # Create temporary signature for binding
        temp_sig = init_sig.replace(parameters=params)

        # Bind and apply defaults
        bound_args = temp_sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Store unified kwargs; name-mangled to `_PyTorchRegistryMixin__init_kwargs`.
        instance.__init_kwargs = bound_args.arguments
        return instance

    def upload_model(
        self,
        name: Optional[str] = None,
        version: Optional[str] = None,
        temp_folder: Union[str, Path, None] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Push the model to the registry.

        Args:
            name: The name of the model. If not use the class name.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
            metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

        Raises:
            TypeError: If the object is not a PyTorch `nn.Module`.
            RuntimeError: If the init arguments cannot be serialized to JSON.
        """
        import torch

        if not isinstance(self, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")

        name, model_name, temp_folder = self._setup(name, temp_folder)

        init_kwargs_path = None
        # `self.__init_kwargs` is name-mangled, so look up the mangled attribute.
        # The previous check `hasattr(self, "__init_kwargs")` could never match the
        # mangled name: it warned spuriously for models with no init arguments, and
        # a truly missing attribute crashed with AttributeError before the check.
        init_kwargs = getattr(self, "_PyTorchRegistryMixin__init_kwargs", None)
        if init_kwargs is None:
            rank_zero_warn(
                "The child class is missing `__init_kwargs`."
                " Ensure `PyTorchRegistryMixin` is first in the inheritance order"
                " or call `PyTorchRegistryMixin.__init__` explicitly in the child class."
            )
        elif init_kwargs:
            try:
                # Save the model arguments to a JSON file
                init_kwargs_path = Path(temp_folder) / f"{model_name}__init_kwargs.json"
                with open(init_kwargs_path, "w") as fp:
                    json.dump(init_kwargs, fp)
            except Exception as ex:
                raise RuntimeError(
                    f"Failed to save model arguments: {ex}."
                    " Ensure the model's arguments are JSON serializable or use `PickleRegistryMixin`."
                ) from ex

        torch_state_dict_path = Path(temp_folder) / f"{model_name}.pth"
        torch.save(self.state_dict(), torch_state_dict_path)
        model_registry = f"{name}:{version}" if version else name
        # todo: consider creating another temp folder and copying these two files
        # todo: updating SDK to support uploading just specific files
        uploaded_files = [torch_state_dict_path]
        if init_kwargs_path:
            uploaded_files.append(init_kwargs_path)
        self._upload_model_files(name=model_registry, path=uploaded_files, metadata=metadata)

    @classmethod
    def download_model(
        cls,
        name: str,
        version: Optional[str] = None,
        temp_folder: Union[str, Path, None] = None,
        torch_load_kwargs: Optional[dict] = None,
    ) -> "torch.nn.Module":
        """Pull the model from the registry.

        Args:
            name: The name of the model.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
            torch_load_kwargs: Additional arguments to pass to `torch.load()`.

        Raises:
            RuntimeError: If zero or multiple state-dict / init-kwargs files are found.
            TypeError: If the reconstructed object is not a PyTorch `nn.Module`.
        """
        import torch

        if temp_folder is None:
            temp_folder = tempfile.mkdtemp()
        model_registry = f"{name}:{version}" if version else name
        files = download_model_files(name=model_registry, download_dir=temp_folder)

        torch_files = [f for f in files if f.endswith(".pth")]
        if not torch_files:
            raise RuntimeError(f"No torch file found for model: {model_registry} with {files}")
        if len(torch_files) > 1:
            raise RuntimeError(f"Multiple torch files found for model: {model_registry} with {torch_files}")
        state_dict_path = Path(temp_folder) / torch_files[0]
        # ignore future warning about changed default
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            state_dict = torch.load(state_dict_path, **(torch_load_kwargs if torch_load_kwargs else {}))

        init_files = [fp for fp in files if fp.endswith("__init_kwargs.json")]
        if not init_files:
            init_kwargs = {}
        elif len(init_files) > 1:
            raise RuntimeError(f"Multiple init files found for model: {model_registry} with {init_files}")
        else:
            init_kwargs_path = Path(temp_folder) / init_files[0]
            with open(init_kwargs_path) as fp:
                init_kwargs = json.load(fp)

        # Recreate the model with the recorded init arguments...
        instance = cls(**init_kwargs)
        if not isinstance(instance, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(instance)}")
        # ...then load the downloaded weights onto it.
        instance.load_state_dict(state_dict, strict=True)
        return instance
249 |
--------------------------------------------------------------------------------
/src/litmodels/io/__init__.py:
--------------------------------------------------------------------------------
1 | """Root package for Input/output."""
2 |
3 | from litmodels.io.cloud import download_model_files, upload_model_files # noqa: F401
4 | from litmodels.io.gateway import download_model, load_model, save_model, upload_model
5 |
6 | __all__ = ["download_model", "upload_model", "upload_model_files", "load_model", "save_model"]
7 |
--------------------------------------------------------------------------------
/src/litmodels/io/cloud.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # http://www.apache.org/licenses/LICENSE-2.0
4 | #
5 | from pathlib import Path
6 | from typing import TYPE_CHECKING, Dict, List, Optional, Union
7 |
8 | from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
9 | from lightning_sdk.models import _extend_model_name_with_teamspace, _parse_org_teamspace_model_version
10 | from lightning_sdk.models import delete_model as sdk_delete_model
11 | from lightning_sdk.models import download_model as sdk_download_model
12 | from lightning_sdk.models import upload_model as sdk_upload_model
13 |
14 | import litmodels
15 |
16 | if TYPE_CHECKING:
17 | from lightning_sdk.models import UploadedModelInfo
18 |
19 |
# URLs already printed once, so verbose=1 does not repeat them.
_SHOWED_MODEL_LINKS = []


def _print_model_link(name: str, verbose: Union[bool, int]) -> None:
    """Print a link to the uploaded model.

    Args:
        name: Name of the model.
        verbose: Whether to print the link:

            - If set to 0, no link will be printed.
            - If set to 1, the link will be printed only once.
            - If set to 2, the link will be printed every time.
    """
    name = _extend_model_name_with_teamspace(name)
    org_name, teamspace_name, model_name, _ = _parse_org_teamspace_model_version(name)

    url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}"
    msg = f"Model uploaded successfully. Link to the model: '{url}'"
    if int(verbose) > 1:
        # always print, without recording the URL
        print(msg)
        return
    if url in _SHOWED_MODEL_LINKS:
        return
    print(msg)
    _SHOWED_MODEL_LINKS.append(url)
44 |
45 |
def upload_model_files(
    name: str,
    path: Union[str, Path, List[Union[str, Path]]],
    progress_bar: bool = True,
    cloud_account: Optional[str] = None,
    verbose: Union[bool, int] = 1,
    metadata: Optional[Dict[str, str]] = None,
) -> "UploadedModelInfo":
    """Upload a local checkpoint file to the model store.

    Args:
        name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where entity is either your username or the name of an organization you are part of.
        path: Path to the model file to upload.
        progress_bar: Whether to show a progress bar for the upload.
        cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
            automatically.
        verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.
        metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

    Returns:
        Info about the uploaded model as returned by the SDK.
    """
    # Work on a copy so the caller's dict is not mutated as a side effect.
    metadata = dict(metadata) if metadata else {}
    metadata.update({"litModels": litmodels.__version__})
    info = sdk_upload_model(
        name=name,
        path=path,
        progress_bar=progress_bar,
        cloud_account=cloud_account,
        metadata=metadata,
    )
    if verbose:
        _print_model_link(name, verbose)
    return info
80 |
81 |
def download_model_files(
    name: str,
    download_dir: Union[str, Path] = ".",
    progress_bar: bool = True,
) -> Union[str, List[str]]:
    """Download a checkpoint from the model store.

    Args:
        name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
            where entity is either your username or the name of an organization you are part of.
        download_dir: A path to directory where the model should be downloaded. Defaults
            to the current working directory.
        progress_bar: Whether to show a progress bar for the download.

    Returns:
        The absolute path to the downloaded model file or folder.
    """
    # Thin pass-through to the SDK implementation.
    downloaded = sdk_download_model(name=name, download_dir=download_dir, progress_bar=progress_bar)
    return downloaded
104 |
105 |
def _list_available_teamspaces() -> Dict[str, dict]:
    """List available teamspaces for the authenticated user.

    Returns:
        Dict with "owner/teamspace" names as keys and their details as values.

    Raises:
        RuntimeError: If a membership has an unrecognized owner type.
    """
    from lightning_sdk.api import OrgApi, UserApi
    from lightning_sdk.utils import resolve as sdk_resolvers

    org_api = OrgApi()
    user = sdk_resolvers._get_authed_user()
    teamspaces = {}
    for ts in UserApi()._get_all_teamspace_memberships(""):
        if ts.owner_type == "organization":
            org = org_api._get_org_by_id(ts.owner_id)
            teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name}
        elif ts.owner_type == "user":  # todo: check also the name
            teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user}
        else:
            # fix: the message previously read `ts.organization_type`, an attribute that
            # is not used anywhere else here and would raise AttributeError instead of
            # the intended RuntimeError.
            raise RuntimeError(f"Unknown owner type {ts.owner_type}")
    return teamspaces
127 |
128 |
def delete_model_version(
    name: str,
    version: Optional[str] = None,
) -> None:
    """Delete a model version from the model store.

    Args:
        name: Name of the model to delete. Must be in the format 'organization/teamspace/modelname'
            where entity is either your username or the name of an organization you are part of.
        version: Version of the model to delete. If not provided, all versions will be deleted.
    """
    # Without a version the whole model is deleted; the previous code always built
    # the name with the version suffix, producing the literal "<name>:None".
    sdk_delete_model(name=f"{name}:{version}" if version else name)
141 |
--------------------------------------------------------------------------------
/src/litmodels/io/gateway.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from pathlib import Path
4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5 |
6 | from litmodels.io.cloud import download_model_files, upload_model_files
7 | from litmodels.io.utils import _KERAS_AVAILABLE, _PYTORCH_AVAILABLE, dump_pickle, load_pickle
8 |
9 | if _PYTORCH_AVAILABLE:
10 | import torch
11 |
12 | if _KERAS_AVAILABLE:
13 | from tensorflow import keras
14 |
15 | if TYPE_CHECKING:
16 | from lightning_sdk.models import UploadedModelInfo
17 |
18 |
def upload_model(
    name: str,
    model: Union[str, Path],
    progress_bar: bool = True,
    cloud_account: Optional[str] = None,
    verbose: Union[bool, int] = 1,
    metadata: Optional[Dict[str, str]] = None,
) -> "UploadedModelInfo":
    """Upload a checkpoint to the model store.

    Args:
        name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where entity is either your username or the name of an organization you are part of.
        model: The model to upload. Can be a path to a checkpoint file or a folder.
        progress_bar: Whether to show a progress bar for the upload.
        cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
            automatically.
        verbose: Whether to print some additional information about the uploaded model.
        metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

    Returns:
        Info about the uploaded model.

    Raises:
        ValueError: If ``model`` is not a file-system path; in-memory objects go through ``save_model``.
    """
    if not isinstance(model, (str, Path)):
        raise ValueError(
            "The `model` argument should be a path to a file or folder, not a Python object."
            " For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead."
        )

    return upload_model_files(
        path=model,
        name=name,
        progress_bar=progress_bar,
        cloud_account=cloud_account,
        verbose=verbose,
        metadata=metadata,
    )
54 |
55 |
def save_model(
    name: str,
    model: Union["torch.nn.Module", Any],
    progress_bar: bool = True,
    cloud_account: Optional[str] = None,
    staging_dir: Optional[str] = None,
    verbose: Union[bool, int] = 1,
    metadata: Optional[Dict[str, str]] = None,
) -> "UploadedModelInfo":
    """Serialize a model into a staging folder and upload it to the model store.

    Args:
        name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where entity is either your username or the name of an organization you are part of.
        model: The model to upload. Can be a PyTorch model or a Lightning model.
        progress_bar: Whether to show a progress bar for the upload.
        cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
            automatically.
        staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
            be created and used.
        verbose: Whether to print some additional information about the uploaded model.
        metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

    Raises:
        ValueError: If ``model`` is a path; file or folder paths go through ``upload_model``.
    """
    if isinstance(model, (str, Path)):
        raise ValueError(
            "The `model` argument should be a PyTorch model or a Lightning model, not a path to a file."
            " With file or folder path use `upload_model` instead."
        )

    if not staging_dir:
        staging_dir = tempfile.mkdtemp()
    if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
        # TorchScript modules carry their own program, so save them as-is
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
        model.save(path)
    elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module):
        # plain nn.Module: persist only the state dict
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
        torch.save(model.state_dict(), path)
    elif _KERAS_AVAILABLE and isinstance(model, keras.models.Model):
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.keras")
        model.save(path)
    else:
        # fallback for arbitrary Python objects
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
        dump_pickle(model=model, path=path)

    # copy before updating so the caller's metadata dict is never mutated
    metadata = dict(metadata) if metadata else {}
    metadata.update({"litModels.integration": "save_model"})

    return upload_model(
        model=path,
        name=name,
        progress_bar=progress_bar,
        cloud_account=cloud_account,
        verbose=verbose,
        metadata=metadata,
    )
116 |
117 |
def download_model(
    name: str,
    download_dir: Union[str, Path] = ".",
    progress_bar: bool = True,
) -> Union[str, List[str]]:
    """Fetch a model's files from the model store.

    Args:
        name: Name of the model to download, in the format
            'organization/teamspace/modelname' where entity is either your
            username or the name of an organization you are part of.
        download_dir: Target directory for the download; defaults to the
            current working directory.
        progress_bar: Show a progress bar while downloading.

    Returns:
        The absolute path to the downloaded model file or folder.
    """
    # thin pass-through to the cloud layer
    return download_model_files(name=name, download_dir=download_dir, progress_bar=progress_bar)
140 |
141 |
def load_model(name: str, download_dir: str = ".") -> Any:
    """Download a model from the model store and load it into memory.

    Args:
        name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
            where entity is either your username or the name of an organization you are part of.
        download_dir: A path to directory where the model should be downloaded. Defaults
            to the current working directory.

    Returns:
        The loaded model.

    Raises:
        RuntimeError: If the downloaded model contains no loadable files.
        NotImplementedError: If the model has multiple files or an unsupported suffix.
    """
    download_paths = download_model(name=name, download_dir=download_dir)
    # `download_model` is typed to return either a single path or a list of paths;
    # normalize so a bare string is not iterated character by character below
    if isinstance(download_paths, str):
        download_paths = [download_paths]
    # filter out all Markdown, TXT and RST files
    download_paths = [p for p in download_paths if Path(p).suffix.lower() not in {".md", ".txt", ".rst"}]
    if not download_paths:
        raise RuntimeError(f"Model '{name}' does not contain any loadable files.")
    if len(download_paths) > 1:
        raise NotImplementedError("Downloaded model with multiple files is not supported yet.")
    model_path = Path(download_dir) / download_paths[0]
    suffix = model_path.suffix.lower()
    if suffix == ".ts":
        return torch.jit.load(model_path)
    if suffix == ".keras":
        return keras.models.load_model(model_path)
    if suffix == ".pkl":
        return load_pickle(path=model_path)
    raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")
167 |
--------------------------------------------------------------------------------
/src/litmodels/io/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from pathlib import Path
3 | from typing import Any, Union
4 |
5 | from lightning_utilities import module_available
6 | from lightning_utilities.core.imports import RequirementCache
7 |
# Optional-dependency availability flags; each one guards the imports/usages below.
_JOBLIB_AVAILABLE = module_available("joblib")
_PYTORCH_AVAILABLE = module_available("torch")
_TENSORFLOW_AVAILABLE = module_available("tensorflow")
# probe the version requirement, not just presence: keras is used via `tensorflow` >=2.0
_KERAS_AVAILABLE = RequirementCache("tensorflow >=2.0.0")

if _JOBLIB_AVAILABLE:
    import joblib
15 |
16 |
def dump_pickle(model: Any, path: Union[str, Path]) -> None:
    """Serialize ``model`` to ``path``.

    Prefers joblib (with compression) when it is installed; otherwise falls
    back to the standard-library :mod:`pickle`.

    Args:
        model: The model to be pickled.
        path: The path where the model will be saved.
    """
    if _JOBLIB_AVAILABLE:
        joblib.dump(model, filename=path, compress=7)
        return
    with open(path, "wb") as file_handle:
        pickle.dump(model, file_handle, protocol=pickle.HIGHEST_PROTOCOL)
29 |
30 |
def load_pickle(path: Union[str, Path]) -> Any:
    """Deserialize a model previously written by :func:`dump_pickle`.

    Args:
        path: The path to the pickle file.

    Returns:
        The unpickled model.
    """
    # mirror dump_pickle: joblib when installed, plain pickle otherwise
    if _JOBLIB_AVAILABLE:
        return joblib.load(path)
    with open(path, "rb") as file_handle:
        return pickle.load(file_handle)
44 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Configure local testing."""
2 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Pytest configuration for integration tests."""
2 |
3 | import pytest
4 |
5 | from litmodels.integrations.checkpoints import get_model_manager
6 |
7 |
@pytest.fixture(autouse=True)
def reset_model_manager():
    """Give every test a freshly initialized model manager."""
    # drop the cached singleton from any previous test
    get_model_manager.cache_clear()
    # and re-create it eagerly so the test starts from a known state
    manager = get_model_manager()
    return manager
13 |
--------------------------------------------------------------------------------
/tests/integrations/__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from litmodels.integrations.imports import (
4 | _LIGHTNING_AVAILABLE,
5 | _LIGHTNING_GREATER_EQUAL_2_5_1,
6 | _PYTORCHLIGHTNING_AVAILABLE,
7 | _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1,
8 | )
9 |
# Skip markers shared by the integration tests.
# NOTE: "BELLOW" (sic) is spelled this way wherever the markers are imported.
_SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")
_SKIP_IF_LIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
    not _LIGHTNING_GREATER_EQUAL_2_5_1, reason="Lightning without integration introduced in 2.5.1"
)
_SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif(
    not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"
)
_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
    not _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, reason="PyTorch Lightning without integration introduced in 2.5.1"
)

# Organization and teamspace used by the real-cloud integration tests.
LIT_ORG = "lightning-ai"
LIT_TEAMSPACE = "OSS | litModels"
23 |
--------------------------------------------------------------------------------
/tests/integrations/test_checkpoints.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import re
3 | from unittest import mock
4 |
5 | import pytest
6 |
7 | import litmodels
8 | from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING
9 |
10 |
@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
    ],
)
@pytest.mark.parametrize(
    "model_name", [None, "org-name/teamspace/model-name", "model-in-studio", "model-user-only-project"]
)
@pytest.mark.parametrize("clear_all_local", [True, False])
@pytest.mark.parametrize("keep_all_uploaded", [True, False])
@mock.patch("litmodels.io.cloud.sdk_delete_model")
@mock.patch("litmodels.io.cloud.sdk_upload_model")
@mock.patch("litmodels.integrations.checkpoints.Auth")
def test_lightning_checkpoint_callback(
    mock_auth,
    mock_upload_model,
    mock_delete_model,
    monkeypatch,
    importing,
    model_name,
    clear_all_local,
    keep_all_uploaded,
    tmp_path,
):
    """Exercise the LitModel checkpoint callback end-to-end with mocked SDK calls.

    Covers full and partial registry names, local checkpoint pruning
    (``clear_all_local``) and cloud version pruning (``keep_all_uploaded``)
    for both Lightning packages.
    """
    if importing == "lightning":
        from lightning.pytorch import Trainer
        from lightning.pytorch.callbacks import ModelCheckpoint
        from lightning.pytorch.demos.boring_classes import BoringModel

        from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
    elif importing == "pytorch_lightning":
        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint
        from pytorch_lightning.demos.boring_classes import BoringModel

        from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint

    # Validate inheritance
    assert issubclass(LitModelCheckpoint, ModelCheckpoint)

    ckpt_args = {"clear_all_local": clear_all_local, "keep_all_uploaded": keep_all_uploaded}
    if model_name:
        ckpt_args.update({"model_registry": model_name})

    # expected org/teamspace/model resolution for each parametrized registry name
    all_model_registry = {
        "org-name/teamspace/model-name": {"org": "org-name", "teamspace": "teamspace", "model": "model-name"},
        "model-in-studio": {"org": "my-org", "teamspace": "dream-team", "model": "model-in-studio"},
        "model-user-only-project": {"org": "my-org", "teamspace": "default-ts", "model": "model-user-only-project"},
    }
    expected_boring_model = "BoringModel_20250102-1213"
    expected_model_registry = all_model_registry.get(
        model_name,
        {"org": "org-name", "teamspace": "teamspace", "model": expected_boring_model},
    )
    expected_org = expected_model_registry["org"]
    expected_teamspace = expected_model_registry["teamspace"]
    expected_model = expected_model_registry["model"]
    mock_upload_model.return_value.name = f"{expected_org}/{expected_teamspace}/{expected_model}"
    # pin the default (timestamped) model name so the assertions below are deterministic
    monkeypatch.setattr(
        "litmodels.integrations.checkpoints.LitModelCheckpointMixin.default_model_name",
        mock.MagicMock(return_value=expected_boring_model),
    )
    if model_name is None or model_name == "model-in-studio":
        # teamspace resolvable from the environment (as it would be inside a studio)
        mock_teamspace = mock.Mock(owner=mock.Mock())
        mock_teamspace.owner.name = expected_org
        mock_teamspace.name = expected_teamspace

        monkeypatch.setattr(
            "litmodels.integrations.checkpoints._resolve_teamspace", mock.MagicMock(return_value=mock_teamspace)
        )
    elif model_name == "model-user-only-project":
        # no resolvable teamspace -> the callback falls back to listing available teamspaces
        monkeypatch.setattr("litmodels.integrations.checkpoints._resolve_teamspace", mock.MagicMock(return_value=None))
        monkeypatch.setattr(
            "litmodels.integrations.checkpoints._list_available_teamspaces",
            mock.MagicMock(return_value={f"{expected_org}/{expected_teamspace}": {}}),
        )

    # mocking the trainer delete checkpoint removal
    mock_remove_ckpt = mock.Mock()
    # setting the Trainer and custom checkpointing
    trainer = Trainer(
        max_epochs=2,
        callbacks=LitModelCheckpoint(**ckpt_args),
    )
    trainer.strategy.remove_checkpoint = mock_remove_ckpt
    trainer.fit(BoringModel())

    assert mock_auth.call_count == 1
    # one upload per epoch, versioned by the checkpoint file stem
    assert mock_upload_model.call_args_list == [
        mock.call(
            name=f"{expected_org}/{expected_teamspace}/{expected_model}:{v}",
            path=mock.ANY,
            progress_bar=True,
            cloud_account=None,
            metadata={"litModels.integration": LitModelCheckpoint.__name__, "litModels": litmodels.__version__},
        )
        for v in ("epoch=0-step=64", "epoch=1-step=128")
    ]
    expected_local_removals = 2 if clear_all_local else 1
    assert mock_remove_ckpt.call_count == expected_local_removals

    # unless keep_all_uploaded is set, the first uploaded cloud version gets deleted
    expected_cloud_removals = 0 if keep_all_uploaded else 1
    assert mock_delete_model.call_count == expected_cloud_removals
    if expected_cloud_removals:
        mock_delete_model.assert_called_once_with(
            name=f"{expected_org}/{expected_teamspace}/{expected_model}:epoch=0-step=64"
        )

    # Verify paths match the expected pattern
    for call_args in mock_upload_model.call_args_list:
        path = call_args[1]["path"]
        assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path)
125 |
126 |
@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
    ],
)
@mock.patch("litmodels.integrations.checkpoints.Auth")
def test_lightning_checkpointing_pickleable(mock_auth, importing):
    """The checkpoint callback must survive a pickle round-trip."""
    if importing == "pytorch_lightning":
        from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
    else:
        from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint

    checkpoint = LitModelCheckpoint(model_registry="org-name/teamspace/model-name")
    # constructing the callback authenticates exactly once
    assert mock_auth.call_count == 1
    pickle.dumps(checkpoint)
144 |
--------------------------------------------------------------------------------
/tests/integrations/test_duplicate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from unittest import mock
3 |
4 | from litmodels.integrations.duplicate import duplicate_hf_model
5 |
6 |
@mock.patch("litmodels.integrations.duplicate.snapshot_download")
@mock.patch("litmodels.integrations.duplicate.upload_model_files")
def test_duplicate_hf_model(mock_upload_model, mock_snapshot_download, tmp_path):
    """Verify that the HF model can be duplicated to the teamspace"""

    hf_model = "google/t5-efficient-tiny"
    # model name with random hash
    model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
    duplicate_hf_model(hf_model=hf_model, lit_model=model_name, local_workdir=str(tmp_path))

    # the HF snapshot must land in a workdir subfolder named after the repo id ("/" -> "_")
    mock_snapshot_download.assert_called_with(
        repo_id=hf_model,
        revision="main",
        repo_type="model",
        local_dir=tmp_path / hf_model.replace("/", "_"),
        local_dir_use_symlinks=True,
        ignore_patterns=[".cache*"],
        max_workers=os.cpu_count(),
    )
    # the downloaded folder is then uploaded with metadata linking back to the HF repo
    mock_upload_model.assert_called_with(
        name=f"{model_name}",
        path=tmp_path / hf_model.replace("/", "_"),
        metadata={"hf_model": hf_model, "litModels.integration": "duplicate_hf_model"},
        verbose=1,
    )
32 |
--------------------------------------------------------------------------------
/tests/integrations/test_mixins.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | import pytest
4 | import torch
5 | from torch import nn
6 |
7 | from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
8 |
9 |
class DummyModel(PickleRegistryMixin):
    """Trivial pickle-registry model carrying a single payload value."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # equal iff the other object is a DummyModel with the same payload
        if not isinstance(other, DummyModel):
            return False
        return self.value == other.value
16 |
17 |
@mock.patch("litmodels.integrations.mixins.upload_model_files")
@mock.patch("litmodels.integrations.mixins.download_model_files")
def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
    """Round-trip a PickleRegistryMixin model through mocked upload/download."""
    # Create an instance of DummyModel and upload it.
    dummy = DummyModel(42)
    dummy.upload_model(version="v1", temp_folder=str(tmp_path))
    # The expected registry name is "DummyModel:v1" and the file should be placed in the temp folder.
    expected_path = tmp_path / "DummyModel.pkl"
    mock_upload_model.assert_called_once_with(
        name="DummyModel:v1", path=expected_path, metadata={"litModels.integration": "PickleRegistryMixin"}
    )

    # Set the mock to return the full path to the pickle file.
    mock_download_model.return_value = ["DummyModel.pkl"]
    # Download and unpickle the DummyModel instance.
    loaded_dummy = DummyModel.download_model(name="dummy_model", version="v1", temp_folder=str(tmp_path))
    # Verify that the unpickled instance has the expected value.
    assert loaded_dummy.value == 42
36 |
37 |
class DummyTorchModelFirst(PyTorchRegistryMixin, nn.Module):
    """Registry-aware model with the mixin placed first in the MRO.

    With the mixin first, ``PyTorchRegistryMixin.__init__`` captures the
    constructor arguments automatically via ``super().__init__()``.
    """

    def __init__(self, input_size: int, output_size: int = 10):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        # flatten everything but the batch dimension before the linear layer
        return self.fc(x.view(x.size(0), -1))
47 |
48 |
class DummyTorchModelSecond(nn.Module, PyTorchRegistryMixin):
    """Registry-aware model with the mixin last in the MRO.

    Because the mixin is not first, its ``__init__`` must be invoked
    explicitly as an unbound call — which requires passing ``self``.
    """

    def __init__(self, input_size: int, output_size: int = 10):
        # pass `self` explicitly: an unbound __init__ call does not bind the instance,
        # so the previous call handed `input_size` in as `self`
        PyTorchRegistryMixin.__init__(self, input_size, output_size)
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc(x)
58 |
59 |
@pytest.mark.parametrize("torch_class", [DummyTorchModelFirst, DummyTorchModelSecond])
@mock.patch("litmodels.integrations.mixins.upload_model_files")
@mock.patch("litmodels.integrations.mixins.download_model_files")
def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, torch_class, tmp_path):
    """Round-trip a PyTorchRegistryMixin model through mocked upload/download."""
    # Create an instance, push the model and record its forward output.
    dummy = torch_class(784)
    dummy.eval()
    input_tensor = torch.randn(1, 784)
    output_before = dummy(input_tensor)

    # Pre-create the artifacts the download step will look for in the temp folder:
    # the state dict and the JSON file with the captured constructor kwargs.
    torch_file = f"{dummy.__class__.__name__}.pth"
    torch.save(dummy.state_dict(), tmp_path / torch_file)
    json_file = f"{dummy.__class__.__name__}__init_kwargs.json"
    json_path = tmp_path / json_file
    with open(json_path, "w") as fp:
        fp.write('{"input_size": 784, "output_size": 10}')

    dummy.upload_model(temp_folder=str(tmp_path))
    # the upload must include both the weights file and the init-kwargs JSON
    mock_upload_model.assert_called_once_with(
        name=torch_class.__name__,
        path=[tmp_path / f"{torch_class.__name__}.pth", json_path],
        metadata={"litModels.integration": "PyTorchRegistryMixin"},
    )

    # Prepare mocking for the download/rebuild step.
    mock_download_model.return_value = [torch_file, json_file]
    loaded_dummy = torch_class.download_model(name=torch_class.__name__, temp_folder=str(tmp_path))
    loaded_dummy.eval()
    output_after = loaded_dummy(input_tensor)

    assert isinstance(loaded_dummy, torch_class)
    # Compare the outputs as a verification.
    assert torch.allclose(output_before, output_after), "Loaded model output differs from original."
93 |
--------------------------------------------------------------------------------
/tests/integrations/test_real_cloud.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | from contextlib import redirect_stdout
4 | from io import StringIO
5 | from typing import Optional
6 |
7 | import pytest
8 | import torch
9 | from lightning_sdk import Teamspace
10 | from lightning_sdk.lightning_cloud.rest_client import GridRestClient
11 | from lightning_sdk.utils.resolve import _resolve_teamspace
12 |
13 | from litmodels import download_model, load_model, save_model, upload_model
14 | from litmodels.integrations.duplicate import duplicate_hf_model
15 | from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
16 | from litmodels.io.cloud import _list_available_teamspaces
17 | from litmodels.io.utils import _KERAS_AVAILABLE
18 | from tests.integrations import (
19 | _SKIP_IF_LIGHTNING_BELLOW_2_5_1,
20 | _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1,
21 | LIT_ORG,
22 | LIT_TEAMSPACE,
23 | )
24 |
25 |
def _prepare_variables(test_name: str) -> tuple[Teamspace, str, str]:
    """Resolve the CI teamspace and build a unique model name for *test_name*."""
    # random suffix keeps concurrent CI runs from colliding on the same model name
    unique_model = f"ci-test_integrations_{test_name}+{os.urandom(8).hex()}"
    teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
    owner_and_name = f"{teamspace.owner.name}/{teamspace.name}"
    return teamspace, owner_and_name, unique_model
31 |
32 |
def _cleanup_model(teamspace: Teamspace, model_name: str, expected_num_versions: Optional[int] = None) -> None:
    """Delete *model_name* from *teamspace*, optionally asserting its version count first."""
    client = GridRestClient()
    # each test run uses a unique model name, so deleting the whole model is safe
    model = client.models_store_get_model_by_name(
        project_owner_name=teamspace.owner.name,
        project_name=teamspace.name,
        model_name=model_name,
    )
    if expected_num_versions is not None:
        versions = client.models_store_list_model_versions(project_id=model.project_id, model_id=model.id)
        assert len(versions.versions) == expected_num_versions
    client.models_store_delete_model(project_id=model.project_id, model_id=model.id)
46 |
47 |
@pytest.mark.cloud
@pytest.mark.parametrize(
    "in_studio",
    [False, pytest.param(True, marks=pytest.mark.skipif(platform.system() != "Linux", reason="Studio is just Linux"))],
)
def test_upload_download_model(in_studio, monkeypatch, tmp_path):
    """Verify that the model is uploaded to the teamspace"""
    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)

    # create a dummy file
    file_path = tmp_path / "dummy.txt"
    with open(file_path, "w") as f:
        f.write("dummy")

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("upload_download")
    # inside a studio the teamspace comes from the env vars, so the bare name suffices
    model_registry = f"{org_team}/{model_name}" if not in_studio else model_name

    out = StringIO()
    with redirect_stdout(out):
        upload_model(name=model_registry, model=file_path)

    # validate the output
    assert (
        f"Model uploaded successfully. Link to the model: 'https://lightning.ai/{org_team}/models/{model_name}'"
    ) in out.getvalue()

    # remove the local copy so the download below proves the file came from the cloud
    os.remove(file_path)
    assert not os.path.isfile(file_path)

    model_files = download_model(name=model_registry, download_dir=tmp_path)
    assert model_files == ["dummy.txt"]
    for file in model_files:
        assert os.path.isfile(os.path.join(tmp_path, file))

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=1)
88 |
89 |
@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
    ],
)
@pytest.mark.parametrize(
    "in_studio",
    [
        False,
        pytest.param(True, marks=pytest.mark.skipif(platform.system() == "Windows", reason="studio is not Windows")),
    ],
)
@pytest.mark.cloud
def test_lightning_default_checkpointing(importing, in_studio, monkeypatch, tmp_path):
    """A Trainer constructed with ``model_registry`` uploads checkpoints to the real store.

    With ``max_epochs=2`` the cleanup asserts that exactly 2 versions were created.
    """
    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)

    if importing == "lightning":
        from lightning import Trainer
        from lightning.pytorch.demos.boring_classes import BoringModel
    elif importing == "pytorch_lightning":
        from pytorch_lightning import Trainer
        from pytorch_lightning.demos.boring_classes import BoringModel

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("default_checkpoint")
    # inside a studio the teamspace comes from the env vars, so the bare name suffices
    model_registry = f"{org_team}/{model_name}" if not in_studio else model_name

    trainer = Trainer(
        max_epochs=2,
        default_root_dir=tmp_path,
        model_registry=model_registry,
    )
    trainer.fit(BoringModel())

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=2)
131 |
132 |
@pytest.mark.parametrize("trainer_method", ["fit", "validate", "test", "predict"])
@pytest.mark.parametrize(
    "registry", ["registry", "registry:version:v1", "registry:<model>", "registry:<model>:version:v1"]
)
@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
    ],
)
@pytest.mark.parametrize(
    "in_studio",
    [
        False,
        pytest.param(True, marks=pytest.mark.skipif(platform.system() == "Windows", reason="studio is not Windows")),
    ],
)
@pytest.mark.cloud
def test_lightning_plain_resume(trainer_method, registry, importing, in_studio, tmp_path, monkeypatch):
    """Resume fit/validate/test/predict from the registry via ``ckpt_path="registry..."`` forms.

    The ``<model>`` token in the parametrized registry strings is a placeholder
    substituted with the real registry path at runtime; without it the bare
    ``"registry"`` forms require ``model_registry`` to be set on the Trainer.
    """
    if importing == "lightning":
        from lightning import Trainer
        from lightning.pytorch.demos.boring_classes import BoringModel
    elif importing == "pytorch_lightning":
        from pytorch_lightning import Trainer
        from pytorch_lightning.demos.boring_classes import BoringModel

    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)

    trainer = Trainer(max_epochs=1, limit_train_batches=50, limit_val_batches=20, default_root_dir=tmp_path)
    trainer.fit(BoringModel())
    checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables(f"resume_{trainer_method}")
    model_registry = f"{org_team}/{model_name}" if not in_studio else model_name
    upload_model(model=checkpoint_path, name=model_registry)
    expected_num_versions = 1

    # without an explicit "<model>" in the ckpt_path, the Trainer needs `model_registry`
    # to know which model to resolve
    trainer_kwargs = {"model_registry": model_registry} if "<model>" not in registry else {}
    trainer = Trainer(
        max_epochs=2,
        default_root_dir=tmp_path,
        limit_train_batches=10,
        limit_val_batches=10,
        limit_test_batches=10,
        limit_predict_batches=10,
        **trainer_kwargs,
    )
    # substitute the placeholder with the real registry path
    registry = registry.replace("<model>", model_registry)
    if trainer_method == "fit":
        trainer.fit(BoringModel(), ckpt_path=registry)
        if trainer_kwargs:
            # fitting with model_registry set uploads another version
            expected_num_versions += 1
    elif trainer_method == "validate":
        trainer.validate(BoringModel(), ckpt_path=registry)
    elif trainer_method == "test":
        trainer.test(BoringModel(), ckpt_path=registry)
    elif trainer_method == "predict":
        trainer.predict(BoringModel(), ckpt_path=registry)
    else:
        raise ValueError(f"Unknown trainer method: {trainer_method}")

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=expected_num_versions)
201 |
202 |
@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
    ],
)
@pytest.mark.cloud
def test_lightning_checkpoint_ddp(importing, tmp_path):
    """Checkpoint upload must also work with a multi-process (ddp_spawn, CPU) strategy."""
    if importing == "lightning":
        from lightning import Trainer
        from lightning.pytorch.demos.boring_classes import BoringModel
    elif importing == "pytorch_lightning":
        from pytorch_lightning import Trainer
        from pytorch_lightning.demos.boring_classes import BoringModel

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("checkpoint_resume")
    trainer_args = {
        "default_root_dir": tmp_path,
        "accelerator": "cpu",
        "strategy": "ddp_spawn",
        "devices": 4,
        "model_registry": f"{org_team}/{model_name}",
    }

    trainer = Trainer(max_epochs=2, **trainer_args)
    trainer.fit(BoringModel())

    # FIXME: seems like barrier is not respected in the test, but in real life it correctly waits for all GPUs
    # trainer = Trainer(max_epochs=5, **trainer_args)
    # trainer.fit(BoringModel(), ckpt_path="registry")

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=2)
238 |
239 |
class DummyModel(PickleRegistryMixin):
    """Minimal pickle-registry model used for the real-cloud round-trip test."""

    def __init__(self, value):
        # payload checked after download to verify the pickle round-trip
        self.value = value
243 |
244 |
@pytest.mark.cloud
def test_pickle_mixin_push_and_pull():
    """Round-trip a pickled model through the real model store."""
    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("pickle_mixin")
    model_registry = f"{org_team}/{model_name}"

    # Create an instance of DummyModel and upload it to the registry.
    dummy = DummyModel(42)
    dummy.upload_model(model_registry)

    # Download the model again and unpickle the DummyModel instance.
    loaded_dummy = DummyModel.download_model(model_registry)
    # Verify that the unpickled instance has the expected value.
    assert isinstance(loaded_dummy, DummyModel)
    assert loaded_dummy.value == 42

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=1)
263 |
264 |
# Dummy PyTorch model exercising PyTorchRegistryMixin.
# The mixin has to come first in the inheritance order; otherwise
# `PyTorchRegistryMixin.__init__` needs to be called explicitly.
class DummyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
    """Single linear layer network used to test push/pull of torch models."""

    def __init__(self, input_size: int, output_size: int = 10):
        # PyTorchRegistryMixin.__init__ captures these constructor arguments.
        super().__init__()
        self.fc = torch.nn.Linear(input_size, output_size)

    def forward(self, x):
        # Flatten everything but the batch dimension before the projection.
        flat = x.view(x.size(0), -1)
        return self.fc(flat)
277 |
278 |
@pytest.mark.cloud
def test_pytorch_mixin_push_and_pull():
    """Upload a torch model via the mixin, download it, and compare forward outputs."""
    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("torch_mixin")
    registry_path = f"{org_team}/{model_name}"

    # Record the forward output of the original model before uploading.
    original = DummyTorchModel(784)
    original.eval()
    sample = torch.randn(1, 784)
    expected = original(sample)

    original.upload_model(registry_path)

    restored = DummyTorchModel.download_model(registry_path)
    restored.eval()
    actual = restored(sample)

    assert isinstance(restored, DummyTorchModel)
    # Matching outputs confirm the weights survived the round trip.
    assert torch.allclose(expected, actual), "Loaded model output differs from original."

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=1)
303 |
304 |
@pytest.mark.cloud
def test_duplicate_real_hf_model(tmp_path):
    """Duplicate a real Hugging Face model into the teamspace, then delete it."""

    # model name with random hash
    model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
    teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
    org_team = f"{teamspace.owner.name}/{teamspace.name}"

    duplicate_hf_model(hf_model="google/t5-efficient-tiny", lit_model=f"{org_team}/{model_name}")

    # Fetch the freshly duplicated model so it can be removed again.
    client = GridRestClient()
    uploaded = client.models_store_get_model_by_name(
        project_owner_name=teamspace.owner.name,
        project_name=teamspace.name,
        model_name=model_name,
    )
    client.models_store_delete_model(project_id=teamspace.id, model_id=uploaded.id)
323 |
324 |
@pytest.mark.cloud
def test_list_available_teamspaces():
    """The caller must see at least one teamspace, including the OSS one."""
    teams = _list_available_teamspaces()
    assert teams
    # using sanitized teamspace name
    assert f"{LIT_ORG}/oss-litmodels" in teams
331 |
332 |
@pytest.mark.cloud
@pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available")
def test_save_load_tensorflow_keras(tmp_path):
    """Save a compiled Keras model to the cloud and load it back.

    Note: the skip reason previously claimed a Windows limitation although the
    condition only checks that TensorFlow/Keras is importable; the message now
    matches the condition (and the wording used in ``tests/test_io_cloud.py``).
    """
    from tensorflow import keras

    # Define a tiny two-layer model.
    model = keras.Sequential([
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ])

    # Compile the model so it is fully serializable.
    model.compile(optimizer="adam", loss="categorical_crossentropy")

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("tf-keras")
    save_model(f"{org_team}/{model_name}", model=model)

    # Load the model back from the registry.
    model_ = load_model(f"{org_team}/{model_name}", download_dir=str(tmp_path))

    # The loaded object must be the same Keras model type as the original.
    assert isinstance(model_, type(model))
    _cleanup_model(teamspace, model_name, expected_num_versions=1)
360 |
--------------------------------------------------------------------------------
/tests/test_io_cloud.py:
--------------------------------------------------------------------------------
1 | import os
2 | from contextlib import nullcontext
3 | from unittest import mock
4 |
5 | import joblib
6 | import pytest
7 | import torch
8 | import torch.jit as torch_jit
9 | from sklearn import svm
10 | from torch.nn import Module
11 |
12 | import litmodels
13 | from litmodels import download_model, load_model, save_model
14 | from litmodels.io import upload_model_files
15 | from litmodels.io.utils import _KERAS_AVAILABLE
16 | from tests.integrations import LIT_ORG, LIT_TEAMSPACE
17 |
18 |
@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
@pytest.mark.parametrize("in_studio", [True, False])
def test_upload_wrong_model_name(name, in_studio, monkeypatch):
    """Invalid model names must be rejected unless the studio env can resolve them."""
    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
        monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.teamspace.TeamspaceApi", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.models._get_teamspace", mock.MagicMock)

    # A bare model name is valid only when org/teamspace come from the studio env.
    expect_ok = in_studio and name == "model-name"
    expectation = nullcontext() if expect_ok else pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
    with expectation:
        upload_model_files(path="path/to/checkpoint", name=name)
38 |
39 |
@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
@pytest.mark.parametrize("in_studio", [True, False])
def test_download_wrong_model_name(name, in_studio, monkeypatch):
    """Invalid model names must be rejected on download unless resolved by the studio env."""
    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
        monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.models.TeamspaceApi", mock.MagicMock)

    # A bare model name is valid only when org/teamspace come from the studio env.
    expect_ok = in_studio and name == "model-name"
    expectation = nullcontext() if expect_ok else pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
    with expectation:
        download_model(name=name)
57 |
58 |
@pytest.mark.parametrize(
    ("model", "model_path", "verbose"),
    [
        # ("path/to/checkpoint", "path/to/checkpoint", False),
        # (BoringModel(), "%s/BoringModel.ckpt"),
        (torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True),
        (Module(), f"%s{os.path.sep}Module.pth", True),
        (svm.SVC(), f"%s{os.path.sep}SVC.pkl", 1),
    ],
)
@mock.patch("litmodels.io.cloud.sdk_upload_model")
def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
    """``save_model`` stages the object to disk and forwards it to the SDK uploader."""
    mock_upload_model.return_value.name = "org-name/teamspace/model-name"

    # The lit-logger function is just a wrapper around the SDK function
    save_model(
        model=model,
        name="org-name/teamspace/model-name",
        cloud_account="cluster_id",
        staging_dir=str(tmp_path),
        verbose=verbose,
    )
    # "%s" templates are filled with the staging dir; literal paths pass through.
    expected_path = model_path if "%" not in model_path else model_path % str(tmp_path)
    mock_upload_model.assert_called_once_with(
        path=expected_path,
        name="org-name/teamspace/model-name",
        cloud_account="cluster_id",
        progress_bar=True,
        metadata={"litModels": litmodels.__version__, "litModels.integration": "save_model"},
    )
89 |
90 |
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_download_model(mock_download_model):
    """``download_model`` delegates directly to the SDK download function."""
    # The lit-logger function is just a wrapper around the SDK function
    download_model(name="org-name/teamspace/model-name", download_dir="where/to/download")
    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir="where/to/download", progress_bar=True
    )
101 |
102 |
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_pickle(mock_download_model, tmp_path):
    """``load_model`` unpickles a ``.pkl`` artifact with joblib."""
    # create a dummy model file for the mocked download to point at
    dumped = tmp_path / "dummy_model.pkl"
    joblib.dump(svm.SVC(), dumped)
    mock_download_model.return_value = [dumped.name]

    # The lit-logger function is just a wrapper around the SDK function
    loaded = load_model(
        name="org-name/teamspace/model-name",
        download_dir=str(tmp_path),
    )
    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
    )
    assert isinstance(loaded, svm.SVC)
120 |
121 |
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_torch_jit(mock_download_model, tmp_path):
    """``load_model`` restores a TorchScript ``.ts`` artifact."""
    # create a dummy scripted module for the mocked download to point at
    scripted_file = tmp_path / "dummy_model.ts"
    torch_jit.script(Module()).save(scripted_file)
    mock_download_model.return_value = [scripted_file.name]

    # The lit-logger function is just a wrapper around the SDK function
    loaded = load_model(
        name="org-name/teamspace/model-name",
        download_dir=str(tmp_path),
    )
    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
    )
    assert isinstance(loaded, torch.jit.ScriptModule)
139 |
140 |
@pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available")
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_tf_keras(mock_download_model, tmp_path):
    """``load_model`` restores a saved ``.keras`` artifact as a Keras model."""
    from tensorflow import keras

    # create a dummy model file for the mocked download to point at
    saved_file = tmp_path / "dummy_model.keras"
    # Define and compile a tiny two-layer model, then save it to disk.
    original = keras.Sequential([
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ])
    original.compile(optimizer="adam", loss="categorical_crossentropy")
    original.save(saved_file)
    # prepare mocked SDK download function
    mock_download_model.return_value = [saved_file.name]

    # The lit-logger function is just a wrapper around the SDK function
    loaded = load_model(
        name="org-name/teamspace/model-name",
        download_dir=str(tmp_path),
    )
    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
    )
    assert isinstance(loaded, keras.models.Model)
167 |
--------------------------------------------------------------------------------