├── .codecov.yml ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── stale.yml └── workflows │ ├── ci-checks.yml │ ├── ci-cloud.yml │ ├── ci-testing.yml │ ├── cleanup-caches.yml │ ├── docs-build.yml │ ├── greetings.yml │ ├── label-conflicts.yml │ └── release-pypi.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── _requirements ├── _ci.txt ├── _docs.txt ├── extra.txt ├── test.txt └── typing.txt ├── docs ├── .build_docs.sh ├── Makefile ├── make.bat └── source │ ├── _static │ ├── copybutton.js │ └── images │ │ ├── icon.svg │ │ ├── logo-large.svg │ │ ├── logo-small.svg │ │ ├── logo.png │ │ └── logo.svg │ ├── _templates │ └── theme_variables.jinja │ ├── conf.py │ └── index.rst ├── examples ├── demo-tensorflow-keras.py ├── demo-upload-download.py ├── resume-lightning-training.py ├── train-model-and-simple-save.py ├── train-model-with-lightning-callback.py └── train-model-with-lightning-logger.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── src └── litmodels │ ├── __about__.py │ ├── __init__.py │ ├── demos │ └── __init__.py │ ├── integrations │ ├── __init__.py │ ├── checkpoints.py │ ├── duplicate.py │ ├── imports.py │ └── mixins.py │ └── io │ ├── __init__.py │ ├── cloud.py │ ├── gateway.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── integrations ├── __init__.py ├── test_checkpoints.py ├── test_duplicate.py ├── test_mixins.py └── test_real_cloud.py └── test_io_cloud.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | # https://docs.codecov.io/docs/codecovyml-reference 6 | codecov: 7 | bot: "codecov-io" 8 | strict_yaml_branch: "yaml-config" 9 | 
require_ci_to_pass: yes 10 | notify: 11 | # after_n_builds: 2 12 | wait_for_ci: yes 13 | 14 | coverage: 15 | precision: 0 # 2 = xx.xx%, 0 = xx% 16 | round: nearest # how coverage is rounded: down/up/nearest 17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 18 | status: 19 | # https://codecov.readme.io/v1.0/docs/commit-status 20 | project: 21 | default: 22 | informational: true 23 | target: 99% # specify the target coverage for each commit status 24 | threshold: 30% # allow up to this much decrease in project coverage 25 | # https://github.com/codecov/support/wiki/Filtering-Branches 26 | # branches: main 27 | if_ci_failed: error 28 | # https://github.com/codecov/support/wiki/Patch-Status 29 | patch: 30 | default: 31 | informational: true 32 | target: 50% # specify the target "X%" coverage to hit 33 | # threshold: 50% # allow this much decrease on patch 34 | changes: false 35 | 36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations 37 | github_checks: 38 | annotations: false 39 | 40 | parsers: 41 | gcov: 42 | branch_detection: 43 | conditional: true 44 | loop: true 45 | macro: false 46 | method: false 47 | javascript: 48 | enable_partials: false 49 | 50 | comment: 51 | layout: header, diff 52 | require_changes: false 53 | behavior: default # update if exists else create new 54 | # branches: * 55 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | 3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence, 4 | # the owners listed on the `*` line below will be requested for review when someone opens a pull request.
5 | * @borda @ethanwharris @justusschock 6 | 7 | # CI/CD and configs 8 | /.github/ @borda 9 | *.yml @borda 10 | 11 | /README.md @williamfalcon 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ### To Reproduce 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. Go to '...' 18 | 2. Run '....' 19 | 3. Scroll down to '....' 20 | 4. See error 21 | 22 | 23 | 24 | #### Code sample 25 | 26 | 28 | 29 | ### Expected behavior 30 | 31 | 32 | 33 | ### Environment 34 | 35 | - PyTorch Version (e.g., 1.0): 36 | - OS (e.g., Linux): 37 | - How you installed PyTorch (`conda`, `pip`, source): 38 | - Build command you used (if compiling from source): 39 | - Python version: 40 | - CUDA/cuDNN version: 41 | - GPU models and configuration: 42 | - Any other relevant information: 43 | 44 | ### Additional context 45 | 46 | 47 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and doc fixes 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## 📚 Documentation 10 | 11 | For typos and doc fixes, please go ahead and: 12 | 13 | 1. Create an issue. 14 | 2. Fix the typo. 15 | 3. Submit a PR. 16 | 17 | Thanks! 
18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Pitch 18 | 19 | 20 | 21 | ### Alternatives 22 | 23 | 24 | 25 | ### Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 |
2 | Before submitting 3 | 4 | - [ ] Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements) 5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section? 6 | - [ ] Did you make sure to update the docs? 7 | - [ ] Did you write any new necessary tests? 8 | 9 |
10 | 11 | ## What does this PR do? 12 | 13 | Fixes # (issue). 14 | 15 | ## PR review 16 | 17 | Anyone in the community is free to review the PR once the tests have passed. 18 | If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged. 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Basic dependabot.yml file with 2 | # minimum configuration for two package managers 3 | 4 | version: 2 5 | updates: 6 | # Enable version updates for python 7 | - package-ecosystem: "pip" 8 | # Look for requirements files in the root directory 9 | directory: "/" 10 | # Check for updates once a day 11 | schedule: 12 | interval: "daily" 13 | # Labels on pull requests for version updates only 14 | labels: 15 | - "dependencies" 16 | pull-request-branch-name: 17 | # Separate sections of the branch name with a hyphen 18 | # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 19 | separator: "-" 20 | # Allow up to 5 open pull requests for pip dependencies 21 | open-pull-requests-limit: 5 22 | 23 | # Enable version updates for GitHub Actions 24 | - package-ecosystem: "github-actions" 25 | directory: "/" 26 | # Check for updates once a month 27 | schedule: 28 | interval: "monthly" 29 | # Labels on pull requests for version updates only 30 | labels: ["CI"] 31 | pull-request-branch-name: 32 | # Separate sections of the branch name with a hyphen for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 33 | separator: "-" 34 | # Allow up to 5 open pull requests for GitHub Actions 35 | open-pull-requests-limit: 5 36 | groups: 37 | GHA-updates: 38 | patterns: 39 | - "*" 40 | -------------------------------------------------------------------------------- /.github/stale.yml:
-------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 60 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 14 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - pinned 10 | - security 11 | # Label to use when marking an issue as stale 12 | staleLabel: won't fix 13 | # Comment to post when marking an issue as stale. Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue. Set to `false` to disable 19 | closeComment: false 20 | 21 | # Set to true to ignore issues in a project (defaults to false) 22 | exemptProjects: true 23 | # Set to true to ignore issues in a milestone (defaults to false) 24 | exemptMilestones: true 25 | # Set to true to ignore issues with an assignee (defaults to false) 26 | exemptAssignees: true 27 | -------------------------------------------------------------------------------- /.github/workflows/ci-checks.yml: -------------------------------------------------------------------------------- 1 | name: General checks 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ ! 
(github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }} 12 | 13 | jobs: 14 | #check-precommit: 15 | # uses: Lightning-AI/utilities/.github/workflows/check-precommit.yml@main 16 | 17 | check-typing: 18 | uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main 19 | with: 20 | actions-ref: main 21 | extra-typing: "typing" 22 | 23 | check-schema: 24 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main 25 | with: 26 | azure-dir: "" 27 | 28 | check-package: 29 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main 30 | with: 31 | actions-ref: main 32 | import-name: "litmodels" 33 | artifact-name: dist-packages-${{ github.sha }} 34 | testing-matrix: | 35 | { 36 | "os": ["ubuntu-latest", "macos-latest", "windows-latest"], 37 | "python-version": ["3.9", "3.12"] 38 | } 39 | 40 | # check-docs: 41 | # uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main 42 | # with: 43 | # requirements-file: "_requirements/_docs.txt" 44 | -------------------------------------------------------------------------------- /.github/workflows/ci-cloud.yml: -------------------------------------------------------------------------------- 1 | name: Cloud integration 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: 5 | push: 6 | workflow_dispatch: 7 | schedule: 8 | - cron: "0 0 * * *" 9 | 10 | defaults: 11 | run: 12 | shell: bash 13 | 14 | jobs: 15 | access-secrets: 16 | # try to access secrets to ensure they are available 17 | # if they are not set, set the job output to false 18 | runs-on: ubuntu-latest 19 | outputs: 20 | secrets_available: ${{ steps.check_secrets.outputs.secrets_available }} 21 | steps: 22 | - name: Check secrets 23 | id: check_secrets 24 | run: | 25 | if [[ -z "${{ secrets.LIGHTNING_USER_ID }}" || -z "${{ secrets.LIGHTNING_API_KEY }}" ]]; then 26 | echo "Secrets are not set. Exiting..."
27 | echo "secrets_available=false" >> $GITHUB_OUTPUT 28 | else 29 | echo "Secrets are available." 30 | echo "secrets_available=true" >> $GITHUB_OUTPUT 31 | fi 32 | 33 | integration: 34 | needs: access-secrets 35 | if: needs.access-secrets.outputs.secrets_available == 'true' 36 | runs-on: ${{ matrix.os }} 37 | strategy: 38 | fail-fast: false 39 | matrix: 40 | os: ["ubuntu-22.04", "macOS-13", "windows-2022"] 41 | python-version: ["3.10"] 42 | 43 | # Timeout: https://stackoverflow.com/a/59076067/4521646 44 | timeout-minutes: 25 45 | env: 46 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 47 | 48 | steps: 49 | - uses: actions/checkout@v4 50 | - uses: actions/setup-python@v5 51 | with: 52 | python-version: ${{ matrix.python-version }} 53 | cache: "pip" 54 | 55 | - name: Install package & dependencies 56 | run: | 57 | pip --version 58 | pip install -e '.[test,extra]' -U -q --find-links $TORCH_URL 59 | pip list 60 | 61 | - name: Test integrations 62 | env: 63 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }} 64 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }} 65 | run: | 66 | coverage run --source litmodels -m pytest src tests -v -m cloud 67 | timeout-minutes: 15 68 | 69 | - name: Statistics 70 | run: | 71 | coverage report 72 | coverage xml 73 | 74 | - name: Upload coverage to Codecov 75 | uses: codecov/codecov-action@v5 76 | with: 77 | token: ${{ secrets.CODECOV_TOKEN }} 78 | files: ./coverage.xml 79 | flags: integrations 80 | env_vars: OS,PYTHON 81 | name: codecov-umbrella 82 | fail_ci_if_error: false 83 | 84 | integration-guardian: 85 | runs-on: ubuntu-latest 86 | needs: integration 87 | if: always() 88 | steps: 89 | - run: echo "${{ needs.integration.result }}" 90 | - name: failing... 91 | if: needs.integration.result == 'failure' 92 | run: exit 1 93 | - name: cancelled or skipped... 
94 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.integration.result) 95 | timeout-minutes: 1 96 | run: sleep 90 97 | 98 | # TODO: add a job that reports failing tests from the scheduled (cron) runs 99 | -------------------------------------------------------------------------------- /.github/workflows/ci-testing.yml: -------------------------------------------------------------------------------- 1 | name: CI testing 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push to any branch, or on pull requests targeting main 5 | push: {} 6 | pull_request: 7 | branches: [main] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | pytester: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: ["ubuntu-24.04", "macOS-13", "windows-2022"] 20 | python-version: ["3.9", "3.12"] 21 | requires: ["latest"] 22 | dependency: ["lightning"] 23 | include: 24 | - { requires: "oldest", dependency: "lightning", os: "ubuntu-22.04", python-version: "3.9" } 25 | - { requires: "latest", dependency: "pytorch_lightning", os: "ubuntu-24.04", python-version: "3.12" } 26 | - { requires: "latest", dependency: "pytorch_lightning", os: "windows-2022", python-version: "3.12" } 27 | - { requires: "latest", dependency: "pytorch_lightning", os: "macOS-13", python-version: "3.12" } 28 | 29 | # Timeout: https://stackoverflow.com/a/59076067/4521646 30 | timeout-minutes: 35 31 | env: 32 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 33 | 34 | steps: 35 | - uses: actions/checkout@v4 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | cache: "pip" 41 | 42 | - name: Set min.
dependencies 43 | if: matrix.requires == 'oldest' 44 | run: | 45 | pip install 'lightning-utilities[cli]' 46 | python -m lightning_utilities.cli requirements set-oldest --req_files='["requirements.txt"]' 47 | 48 | - name: Adjust requirements 49 | run: | 50 | pip install 'lightning-utilities[cli]' -U -q 51 | python -m lightning_utilities.cli requirements replace-pkg \ 52 | --old_package="lightning" \ 53 | --new_package="${{matrix.dependency}}" \ 54 | --req_files='["_requirements/extra.txt"]' 55 | cat _requirements/extra.txt 56 | 57 | - name: Install package & dependencies 58 | run: | 59 | set -e 60 | pip --version 61 | pip install -e '.[test,extra]' -U -q --find-links $TORCH_URL 62 | pip list 63 | # check that the right package was installed 64 | python -c "import ${{matrix.dependency}}; print(${{matrix.dependency}}.__version__)" 65 | 66 | - name: Tests with mocks 67 | run: | 68 | coverage run --source litmodels -m pytest src tests -v -m "not cloud" 69 | 70 | - name: Statistics 71 | run: | 72 | coverage report 73 | coverage xml 74 | 75 | - name: Upload coverage to Codecov 76 | uses: codecov/codecov-action@v5 77 | continue-on-error: true 78 | with: 79 | token: ${{ secrets.CODECOV_TOKEN }} 80 | files: ./coverage.xml 81 | flags: unittests 82 | env_vars: OS,PYTHON 83 | name: codecov-umbrella 84 | fail_ci_if_error: false 85 | 86 | tests-guardian: 87 | runs-on: ubuntu-latest 88 | needs: pytester 89 | if: always() 90 | steps: 91 | - run: echo "${{ needs.pytester.result }}" 92 | - name: failing... 93 | if: needs.pytester.result == 'failure' 94 | run: exit 1 95 | - name: cancelled or skipped...
96 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) 97 | timeout-minutes: 1 98 | run: sleep 90 99 | -------------------------------------------------------------------------------- /.github/workflows/cleanup-caches.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries 2 | name: cleanup caches by a branch 3 | on: 4 | pull_request: 5 | types: [closed] 6 | 7 | jobs: 8 | pr-cleanup: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Check out code 12 | uses: actions/checkout@v4 13 | 14 | - name: Cleanup 15 | run: | 16 | gh extension install actions/gh-actions-cache 17 | 18 | REPO=${{ github.repository }} 19 | BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge" 20 | 21 | echo "Fetching list of cache key" 22 | cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 ) 23 | 24 | ## Setting this to not fail the workflow while deleting cache keys. 25 | set +e 26 | echo "Deleting caches..." 
27 | for cacheKey in $cacheKeysForPR 28 | do 29 | gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm 30 | done 31 | echo "Done" 32 | env: 33 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /.github/workflows/docs-build.yml: -------------------------------------------------------------------------------- 1 | name: "Build (& deploy) Docs" 2 | on: 3 | push: 4 | branches: [main] 5 | workflow_dispatch: 6 | 7 | jobs: 8 | build-docs: 9 | uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main 10 | with: 11 | requirements-file: "_requirements/_docs.txt" 12 | 13 | # https://github.com/marketplace/actions/deploy-to-github-pages 14 | docs-deploy: 15 | needs: build-docs 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 # deploy needs git credentials 19 | - name: Download prepared docs 20 | uses: actions/download-artifact@v4 21 | with: 22 | name: docs-html-${{ github.sha }} 23 | path: docs/build/html 24 | 25 | - name: Deploy 🚀 26 | uses: JamesIves/github-pages-deploy-action@v4.7.3 27 | if: ${{ github.event_name == 'push' }} 28 | with: 29 | token: ${{ secrets.GITHUB_TOKEN }} 30 | branch: gh-pages # The branch the action should deploy to. 31 | folder: docs/build/html # The folder the action should deploy. 
32 | clean: true # Automatically remove deleted files from the deploy branch 33 | target-folder: docs # Push the contents of the deployment folder into this directory of the deploy branch 34 | single-commit: true # Keep a single commit on the deployment branch instead of the full history 35 | -------------------------------------------------------------------------------- /.github/workflows/greetings.yml: -------------------------------------------------------------------------------- 1 | name: Greetings 2 | # https://github.com/marketplace/actions/first-interaction 3 | 4 | on: [issues] # pull_request 5 | 6 | jobs: 7 | greeting: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/first-interaction@v1 11 | with: 12 | repo-token: ${{ secrets.GITHUB_TOKEN }} 13 | issue-message: "Hi! Thanks for your contribution, great first issue!" 14 | pr-message: "Hey, thanks for the input! Please give us a bit of time to review it!" 15 | -------------------------------------------------------------------------------- /.github/workflows/label-conflicts.yml: -------------------------------------------------------------------------------- 1 | name: Label conflicts 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request_target: 7 | types: ["synchronize", "reopened", "opened"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }} 11 | cancel-in-progress: false 12 | 13 | jobs: 14 | triage-conflicts: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd # Oct 25, 2021 18 | with: 19 | CONFLICT_LABEL_NAME: "has conflicts" 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | MAX_RETRIES: 3 22 | WAIT_MS: 5000 23 | -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | #
https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push to the main branch or when a release is published 5 | push: 6 | branches: [main] 7 | release: 8 | types: [published] 9 | 10 | # based on https://github.com/pypa/gh-action-pypi-publish 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install dependencies 22 | run: pip install -U -r _requirements/_ci.txt 23 | - name: Build package 24 | run: python -m build 25 | - name: Check package 26 | run: twine check dist/* 27 | 28 | # Publish to Test PyPI first, since failures there are low-stakes 29 | - name: Publish to Test PyPI 30 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 31 | uses: pypa/gh-action-pypi-publish@v1.12.4 32 | with: 33 | user: __token__ 34 | password: ${{ secrets.TEST_PYPI_PASSWORD }} 35 | repository-url: https://test.pypi.org/legacy/ 36 | 37 | - name: Publish distribution 📦 to PyPI 38 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 39 | uses: pypa/gh-action-pypi-publish@v1.12.4 40 | with: 41 | user: __token__ 42 | password: ${{ secrets.pypi_password }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | #
before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | docs/source/api/ 60 | docs/source/*.md 61 | docs/source/readme.rst 62 | 63 | # PyBuilder 64 | target/ 65 | 66 | # Jupyter Notebook 67 | .ipynb_checkpoints 68 | 69 | # IPython 70 | profile_default/ 71 | ipython_config.py 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 77 | __pypackages__/ 78 | 79 | # Celery stuff 80 | celerybeat-schedule 81 | celerybeat.pid 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # mypy 96 | .mypy_cache/ 97 | .dmypy.json 98 | dmypy.json 99 | 100 | # Pyre type checker 101 | .pyre/ 102 | 103 | # PyCharm 104 | .idea/ 105 | 106 | # Lightning logs 107 | lightning_logs 108 | *.gz 109 | .DS_Store 110 | .*_submit.py 111 | 112 | # Ruff 113 | .ruff_cache/ 114 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" 7 | autoupdate_schedule: "monthly" 8 | # submodules: true 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v5.0.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | exclude: README.md 17 | - id: check-case-conflict 18 | - id: 
check-yaml 19 | - id: check-toml 20 | - id: check-json 21 | - id: check-added-large-files 22 | - id: check-docstring-first 23 | - id: detect-private-key 24 | 25 | - repo: https://github.com/pre-commit/pygrep-hooks 26 | rev: v1.10.0 27 | hooks: 28 | - id: python-use-type-annotations 29 | 30 | - repo: https://github.com/codespell-project/codespell 31 | rev: v2.4.1 32 | hooks: 33 | - id: codespell 34 | additional_dependencies: [tomli] 35 | #args: ["--write-changes"] 36 | 37 | - repo: https://github.com/pre-commit/mirrors-prettier 38 | rev: v3.1.0 39 | hooks: 40 | - id: prettier 41 | files: \.(json|yml|yaml|toml) 42 | # https://prettier.io/docs/en/options.html#print-width 43 | args: ["--print-width=120"] 44 | 45 | - repo: https://github.com/executablebooks/mdformat 46 | rev: 0.7.22 47 | hooks: 48 | - id: mdformat 49 | args: ["--number"] 50 | additional_dependencies: 51 | - mdformat-gfm 52 | - mdformat-black 53 | - mdformat_frontmatter 54 | exclude: CHANGELOG.md 55 | 56 | - repo: https://github.com/astral-sh/ruff-pre-commit 57 | rev: v0.11.12 58 | hooks: 59 | - id: ruff 60 | args: ["--fix"] 61 | - id: ruff-format 62 | - id: ruff 63 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | fail_on_warning: true 12 | 13 | # Optionally build your docs in additional formats such as PDF and ePub 14 | formats: all 15 | 16 | # Optionally set the version of Python and requirements required to build your docs 17 | python: 18 | version: 3.7 19 | install: 20 | - requirements: _requirements/_docs.txt 21 | 
-------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] - YYYY-MM-DD 9 | 10 | ### Added 11 | 12 | ### Changed 13 | 14 | ### Fixed 15 | 16 | ### Removed 17 | 18 | ### Deprecated 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021-2024 Lightning-AI team 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html 2 | graft wheelhouse 3 | 4 | recursive-exclude __pycache__ *.py[cod] *.orig 5 | 6 | # Include the README and CHANGELOG 7 | include *.md 8 | recursive-include src *.md 9 | 10 | # Include the license file 11 | include LICENSE 12 | 13 | # Exclude build configs 14 | exclude *.sh 15 | exclude *.toml 16 | exclude *.svg 17 | exclude *.yml 18 | exclude *.yaml 19 | 20 | # exclude tests from package 21 | recursive-exclude tests * 22 | recursive-exclude site * 23 | exclude tests 24 | 25 | # Exclude the documentation files 26 | recursive-exclude docs * 27 | exclude docs 28 | 29 | # Include the Requirements 30 | include requirements.txt 31 | recursive-include _requirements *.txt 32 | 33 | # Exclude Makefile 34 | exclude Makefile 35 | 36 | prune .git 37 | prune .github 38 | 
prune .circleci 39 | prune notebook* 40 | prune temp* 41 | prune test* 42 | prune benchmark* 43 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean docs 2 | 3 | # to imitate SLURM set only single node 4 | export SLURM_LOCALID=0 5 | # assume you have installed needed packages 6 | export SPHINX_MOCK_REQUIREMENTS=0 7 | 8 | test: clean 9 | pip install -q -r requirements.txt 10 | pip install -q -r _requirements/test.txt 11 | 12 | # use this to run tests 13 | python -m coverage run --source litmodels -m pytest src tests -v --flake8 14 | python -m coverage report 15 | 16 | docs: clean 17 | pip install . --quiet -r _requirements/_docs.txt 18 | python -m sphinx -b html -W --keep-going docs/source docs/build 19 | 20 | clean: 21 | # clean all temp runs 22 | rm -rf $(shell find . -name "mlruns") 23 | rm -rf .mypy_cache 24 | rm -rf .pytest_cache 25 | rm -rf ./docs/build 26 | rm -rf ./docs/source/**/generated 27 | rm -rf ./docs/source/api 28 | rm -rf ./src/*.egg-info 29 | rm -rf ./build 30 | rm -rf ./dist 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Save, share and host AI model checkpoints Lightning fast ⚡ 4 | 5 | Lightning 6 | 7 |
8 | 9 | ______________________________________________________________________ 10 | 11 | Save, load, host, and share models without slowing down training. 12 | **LitModels** minimizes training slowdowns from checkpoint saving. Share public links on Lightning AI or your own cloud with enterprise-grade access controls. 13 | 14 |
15 | 16 |
 17 | ✅ Checkpoint without slowing training.  ✅ Granular access controls.           
 18 | ✅ Load models anywhere.                 ✅ Host on Lightning or your own cloud.
 19 | 
20 | 21 | [![Discord](https://img.shields.io/discord/1077906959069626439?label=Get%20help%20on%20Discord)](https://discord.gg/WajDThKAur) 22 | ![CI testing](https://github.com/Lightning-AI/LitModels/actions/workflows/ci-testing.yml/badge.svg?event=push) 23 | ![Cloud integration](https://github.com/Lightning-AI/LitModels/actions/workflows/ci-cloud.yml/badge.svg?event=push) 24 | [![codecov](https://codecov.io/gh/Lightning-AI/LitModels/graph/badge.svg?token=MQ0PN2cxKo)](https://codecov.io/gh/Lightning-AI/LitModels) 25 | [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/LitModels/blob/main/LICENSE) 26 | 27 |
28 | Quick start • 29 | Examples • 30 | Features • 31 | Performance • 32 | Community • 33 | Docs 34 |
35 | 36 |
37 | 38 | # Quick start 39 | 40 | Install LitModels via pip: 41 | 42 | ```bash 43 | pip install litmodels 44 | ``` 45 | 46 | Toy example ([see real examples](#examples)): 47 | 48 | ```python 49 | import litmodels as lm 50 | import torch 51 | 52 | # save a model 53 | model = torch.nn.Module() 54 | lm.save_model(model=model, name="model-name") 55 | 56 | # load a model 57 | model = lm.load_model(name="model-name") 58 | ``` 59 | 60 | # Examples 61 | 62 |
63 | PyTorch 64 | 65 | Save model: 66 | 67 | ```python 68 | import torch 69 | from litmodels import save_model 70 | 71 | model = torch.nn.Module() 72 | save_model(model=model, name="your_org/your_team/torch-model") 73 | ``` 74 | 75 | Load model: 76 | 77 | ```python 78 | from litmodels import load_model 79 | 80 | model_ = load_model(name="your_org/your_team/torch-model") 81 | ``` 82 | 83 |
84 | 85 |
86 | PyTorch Lightning 87 | 88 | Save model: 89 | 90 | ```python 91 | from lightning import Trainer 92 | from litmodels import upload_model 93 | from litmodels.demos import BoringModel 94 | 95 | # Configure Lightning Trainer 96 | trainer = Trainer(max_epochs=2) 97 | # Define the model and train it 98 | trainer.fit(BoringModel()) 99 | 100 | # Upload the best model to cloud storage 101 | checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") 102 | # Define the model name - this should be unique to your model 103 | upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>") 104 | ``` 105 | 106 | Load model: 107 | 108 | ```python 109 | from lightning import Trainer 110 | from litmodels import download_model 111 | from litmodels.demos import BoringModel 112 | 113 | # Load the model from cloud storage 114 | checkpoint_path = download_model( 115 | # Define the model name and version - this needs to be unique to your model 116 | name="<organization>/<teamspace>/<model-name>:<model-version>", 117 | download_dir="my_models", 118 | ) 119 | print(f"model: {checkpoint_path}") 120 | 121 | # Train the model with extended training period 122 | trainer = Trainer(max_epochs=4) 123 | trainer.fit(BoringModel(), ckpt_path=checkpoint_path) 124 | ``` 125 | 126 |
127 | 128 |
129 | TensorFlow / Keras 130 | 131 | Save model: 132 | 133 | ```python 134 | from tensorflow import keras 135 | 136 | from litmodels import save_model 137 | 138 | # Define the model 139 | model = keras.Sequential( 140 | [ 141 | keras.layers.Dense(10, input_shape=(784,), name="dense_1"), 142 | keras.layers.Dense(10, name="dense_2"), 143 | ] 144 | ) 145 | 146 | # Compile the model 147 | model.compile(optimizer="adam", loss="categorical_crossentropy") 148 | 149 | # Save the model 150 | save_model("lightning-ai/jirka/sample-tf-keras-model", model=model) 151 | ``` 152 | 153 | Load model: 154 | 155 | ```python 156 | from litmodels import load_model 157 | 158 | model_ = load_model( 159 | "lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model" 160 | ) 161 | ``` 162 | 163 |
164 | 165 |
166 | SKLearn 167 | 168 | Save model: 169 | 170 | ```python 171 | from sklearn import datasets, model_selection, svm 172 | from litmodels import save_model 173 | 174 | # Load example dataset 175 | iris = datasets.load_iris() 176 | X, y = iris.data, iris.target 177 | 178 | # Split dataset into training and test sets 179 | X_train, X_test, y_train, y_test = model_selection.train_test_split( 180 | X, y, test_size=0.2, random_state=42 181 | ) 182 | 183 | # Train a simple SVC model 184 | model = svm.SVC() 185 | model.fit(X_train, y_train) 186 | 187 | # Upload the saved model using litmodels 188 | save_model(model=model, name="your_org/your_team/sklearn-svm-model") 189 | ``` 190 | 191 | Use model: 192 | 193 | ```python 194 | from litmodels import load_model 195 | 196 | # Download and load the model file from cloud storage 197 | model = load_model( 198 | name="your_org/your_team/sklearn-svm-model", download_dir="my_models" 199 | ) 200 | 201 | # Example: run inference with the loaded model 202 | sample_input = [[5.1, 3.5, 1.4, 0.2]] 203 | prediction = model.predict(sample_input) 204 | print(f"Prediction: {prediction}") 205 | ``` 206 | 207 |
208 | 209 | # Features 210 | 211 |
212 | PyTorch Lightning Callback 213 | 214 | Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch. 215 | 216 | ```python 217 | import torch.utils.data as data 218 | import torchvision as tv 219 | from lightning import Trainer 220 | from litmodels.integrations import LightningModelCheckpoint 221 | from litmodels.demos import BoringModel 222 | 223 | dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) 224 | train, val = data.random_split(dataset, [55000, 5000]) 225 | 226 | trainer = Trainer( 227 | max_epochs=2, 228 | callbacks=[ 229 | LightningModelCheckpoint( 230 | # Define the model name - this should be unique to your model 231 | model_registry="<organization>/<teamspace>/<model-name>", 232 | ) 233 | ], 234 | ) 235 | trainer.fit( 236 | BoringModel(), 237 | data.DataLoader(train, batch_size=256), 238 | data.DataLoader(val, batch_size=256), 239 | ) 240 | ``` 241 | 242 |
243 | 244 |
245 | Save any Python class as a checkpoint 246 | 247 | Mixin classes streamline model management in Python by modularizing reusable functionalities like saving/loading, enabling consistent, conflict-free, and maintainable code across multiple models. 248 | 249 | **Save model:** 250 | 251 | ```python 252 | from litmodels.integrations.mixins import PickleRegistryMixin 253 | 254 | 255 | class MyModel(PickleRegistryMixin): 256 | def __init__(self, param1, param2): 257 | self.param1 = param1 258 | self.param2 = param2 259 | # Your model initialization code 260 | ... 261 | 262 | 263 | # Create and push a model instance 264 | model = MyModel(param1=42, param2="hello") 265 | model.upload_model(name="my-org/my-team/my-model") 266 | ``` 267 | 268 | Load model: 269 | 270 | ```python 271 | loaded_model = MyModel.download_model(name="my-org/my-team/my-model") 272 | ``` 273 | 274 |
275 | 276 |
277 | Save custom PyTorch models 278 | 279 | Mixin classes centralize serialization logic, eliminating redundant code and ensuring consistent, error-free model persistence across projects. 280 | The `download_model` method bypasses constructor arguments entirely, reconstructing the model directly from the registry with pre-configured architecture and weights, eliminating initialization mismatches. 281 | 282 | Save model: 283 | 284 | ```python 285 | import torch 286 | from litmodels.integrations.mixins import PyTorchRegistryMixin 287 | 288 | 289 | # Important: PyTorchRegistryMixin must be first in the inheritance order 290 | class MyTorchModel(PyTorchRegistryMixin, torch.nn.Module): 291 | def __init__(self, input_size, hidden_size=128): 292 | super().__init__() 293 | self.linear = torch.nn.Linear(input_size, hidden_size) 294 | self.activation = torch.nn.ReLU() 295 | 296 | def forward(self, x): 297 | return self.activation(self.linear(x)) 298 | 299 | 300 | # Create and push the model 301 | model = MyTorchModel(input_size=784) 302 | model.upload_model(name="my-org/my-team/torch-model") 303 | ``` 304 | 305 | Use the model: 306 | 307 | ```python 308 | # Pull the model with the same architecture 309 | loaded_model = MyTorchModel.download_model(name="my-org/my-team/torch-model") 310 | ``` 311 | 312 |
313 | 314 | # Performance 315 | 316 | 319 | 320 | # Community 321 | 322 | 💬 [Get help on Discord](https://discord.com/invite/XncpTy7DSt)\ 323 | 📋 [License: Apache 2.0](https://github.com/Lightning-AI/litModels/blob/main/LICENSE) 324 | -------------------------------------------------------------------------------- /_requirements/_ci.txt: -------------------------------------------------------------------------------- 1 | build ==1.2.* 2 | twine ==6.1.* 3 | -------------------------------------------------------------------------------- /_requirements/_docs.txt: -------------------------------------------------------------------------------- 1 | sphinx >=6.0,<7.0 2 | myst-parser >=2.0.0 3 | nbsphinx >=0.8.5 4 | pandoc >=1.0 5 | pypandoc-binary 6 | docutils >=0.16 7 | sphinxcontrib-fulltoc >=1.0 8 | sphinxcontrib-mockautodoc 9 | sphinx-autodoc-typehints >=1.0 10 | sphinx-paramlinks >=0.5.1 11 | sphinx-togglebutton >=0.2 12 | sphinx-copybutton >=0.3 13 | # jinja2 >=3.0.0,<3.2.0 14 | 15 | pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip 16 | -------------------------------------------------------------------------------- /_requirements/extra.txt: -------------------------------------------------------------------------------- 1 | lightning >= 2.0.0 2 | numpy <2.0.0 ; platform_system == "Darwin" # compatibility fix for Torch 3 | joblib >= 1.0.0 4 | -------------------------------------------------------------------------------- /_requirements/test.txt: -------------------------------------------------------------------------------- 1 | coverage >=5.0 2 | pytest >=6.0 3 | pytest-cov 4 | pytest-mock 5 | 6 | pytorch-lightning >=2.0 7 | scikit-learn >=1.0 8 | huggingface-hub >=0.29.0 9 | tensorflow >=2.0 10 | -------------------------------------------------------------------------------- /_requirements/typing.txt: -------------------------------------------------------------------------------- 1 | mypy ==1.16.0 2 | 
-------------------------------------------------------------------------------- /docs/.build_docs.sh: -------------------------------------------------------------------------------- 1 | make clean 2 | make html --debug --jobs $(nproc) 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = python $(shell which sphinx-build) 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/copybutton.js: -------------------------------------------------------------------------------- 1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */ 2 | $(document).ready(function () { 3 | /* Add a [>>>] button on the top-right corner of code samples to hide 4 | * the >>> and ... prompts and the output and thus make the code 5 | * copyable. */ 6 | var div = $( 7 | ".highlight-python .highlight," + 8 | ".highlight-python3 .highlight," + 9 | ".highlight-pycon .highlight," + 10 | ".highlight-default .highlight", 11 | ); 12 | var pre = div.find("pre"); 13 | 14 | // get the styles from the current theme 15 | pre.parent().parent().css("position", "relative"); 16 | var hide_text = "Hide the prompts and output"; 17 | var show_text = "Show the prompts and output"; 18 | var border_width = pre.css("border-top-width"); 19 | var border_style = pre.css("border-top-style"); 20 | var border_color = pre.css("border-top-color"); 21 | var button_styles = { 22 | cursor: "pointer", 23 | position: "absolute", 24 | top: "0", 25 | right: "0", 26 | "border-color": border_color, 27 | "border-style": border_style, 28 | "border-width": border_width, 29 | color: border_color, 30 | "text-size": "75%", 31 | "font-family": "monospace", 32 | "padding-left": "0.2em", 33 | "padding-right": "0.2em", 34 | "border-radius": "0 3px 0 0", 35 | }; 36 | 37 | // create and add the button to all the code blocks that contain >>> 38 | div.each(function (index) { 39 | var jthis = $(this); 40 | if (jthis.find(".gp").length > 0) { 41 | var button = $('>>>'); 42 | 
button.css(button_styles); 43 | button.attr("title", hide_text); 44 | button.data("hidden", "false"); 45 | jthis.prepend(button); 46 | } 47 | // tracebacks (.gt) contain bare text elements that need to be 48 | // wrapped in a span to work with .nextUntil() (see later) 49 | jthis 50 | .find("pre:has(.gt)") 51 | .contents() 52 | .filter(function () { 53 | return this.nodeType == 3 && this.data.trim().length > 0; 54 | }) 55 | .wrap(""); 56 | }); 57 | 58 | // define the behavior of the button when it's clicked 59 | $(".copybutton").click(function (e) { 60 | e.preventDefault(); 61 | var button = $(this); 62 | if (button.data("hidden") === "false") { 63 | // hide the code output 64 | button.parent().find(".go, .gp, .gt").hide(); 65 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "hidden"); 66 | button.css("text-decoration", "line-through"); 67 | button.attr("title", show_text); 68 | button.data("hidden", "true"); 69 | } else { 70 | // show the code output 71 | button.parent().find(".go, .gp, .gt").show(); 72 | button.next("pre").find(".gt").nextUntil(".gp, .go").css("visibility", "visible"); 73 | button.css("text-decoration", "none"); 74 | button.attr("title", hide_text); 75 | button.data("hidden", "false"); 76 | } 77 | }); 78 | }); 79 | -------------------------------------------------------------------------------- /docs/source/_static/images/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo-large.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo-small.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 
8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/LitModels/7717ff0bc7e76745d9458305257ab19fd2346f5e/docs/source/_static/images/logo.png -------------------------------------------------------------------------------- /docs/source/_static/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/Lightning-AI/litModels', 3 | 'github_issues': 'https://github.com/Lightning-AI/litModels/issues', 4 | 'contributing': 'https://github.com/Lightning-AI/lightning/blob/master/CONTRIBUTING.md', 5 | 'governance': 'https://github.com/Lightning-AI/lightning/blob/master/governance.md', 6 | 'docs': 'https://lightning-ai.github.io/models/', 7 | 'twitter': 'https://twitter.com/LightningAI', 8 | 'discuss': 'https://discord.com/invite/tfXFetEZxv', 9 | 'tutorials': 'https://lightning.ai', 10 | 'previous_pytorch_versions': 'https://lightning-ai.github.io/models/', 11 | 'home': 'https://lightning-ai.github.io/models/', 12 | 'get_started': 'https://lightning.ai', 13 | 'features': 'https://lightning-ai.github.io/models/', 14 | 'blog': 'https://www.Lightning.ai/blog', 15 | 'resources': 'https://lightning.ai', 16 | 'support': 'https://lightning-ai.github.io/models/', 17 | } 18 | -%} 19 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Configuration file for the Sphinx 
documentation builder. 3 | # 4 | # This file does only contain a selection of the most common options. For a 5 | # full list see the documentation: 6 | # http://www.sphinx-doc.org/en/master/config 7 | import glob 8 | import inspect 9 | import os 10 | import re 11 | import sys 12 | from importlib.util import module_from_spec, spec_from_file_location 13 | 14 | import pt_lightning_sphinx_theme 15 | import pypandoc 16 | 17 | _PATH_HERE = os.path.abspath(os.path.dirname(__file__)) 18 | _PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", "..")) 19 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src") 20 | sys.path.insert(0, os.path.abspath(_PATH_ROOT)) 21 | 22 | SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True)) 23 | 24 | # alternative https://stackoverflow.com/a/67692/4521646 25 | spec = spec_from_file_location( 26 | "litmodels/__about__.py", os.path.join(_PATH_SOURCE, "litmodels", "__about__.py") 27 | ) 28 | about = module_from_spec(spec) 29 | spec.loader.exec_module(about) 30 | 31 | # -- Project information ----------------------------------------------------- 32 | 33 | # this name shall match the project name in Github as it is used for linking to code 34 | project = "models" 35 | copyright = about.__copyright__ 36 | author = about.__author__ 37 | 38 | # The short X.Y version 39 | version = about.__version__ 40 | # The full version, including alpha/beta/rc tags 41 | release = about.__version__ 42 | 43 | # Options for the linkcode extension 44 | # ---------------------------------- 45 | github_user = "Lightning-AI" 46 | github_repo = project 47 | 48 | # -- Project documents ------------------------------------------------------- 49 | 50 | 51 | def _transform_changelog(path_in: str, path_out: str) -> None: 52 | with open(path_in) as fp: 53 | chlog_lines = fp.readlines() 54 | # enrich short subsub-titles to be unique 55 | chlog_ver = "" 56 | for i, ln in enumerate(chlog_lines): 57 | if ln.startswith("## "): 58 | chlog_ver = 
ln[2:].split("-")[0].strip() 59 | elif ln.startswith("### "): 60 | ln = ln.replace("###", f"### {chlog_ver} -") 61 | chlog_lines[i] = ln 62 | with open(path_out, "w") as fp: 63 | fp.writelines(chlog_lines) 64 | 65 | 66 | def _convert_markdown(path_in: str, path_out: str) -> None: 67 | with open(path_in) as fp: 68 | readme = fp.read() 69 | # TODO: temp fix removing SVG badges and GIF, because they are automatically 100% wide 70 | readme = re.sub(r"(\[!\[.*\))", "", readme) 71 | readme = re.sub(r"(!\[.*.gif\))", "", readme) 72 | folder_names = ( 73 | os.path.basename(p) 74 | for p in glob.glob(os.path.join(_PATH_ROOT, "*")) 75 | if os.path.isdir(p) 76 | ) 77 | for dir_name in folder_names: 78 | readme = readme.replace( 79 | "](%s/" % dir_name, "](%s/" % os.path.join(_PATH_ROOT, dir_name) 80 | ) 81 | readme = pypandoc.convert_text(readme, format="markdown", to="rst") 82 | with open(path_out, "w") as fp: 83 | fp.write(readme) 84 | 85 | 86 | # export the READme 87 | _convert_markdown(os.path.join(_PATH_ROOT, "README.md"), "readme.rst") 88 | 89 | # -- General configuration --------------------------------------------------- 90 | 91 | # If your documentation needs a minimal Sphinx version, state it here. 92 | 93 | needs_sphinx = "6.2" 94 | 95 | # Add any Sphinx extension module names here, as strings. They can be 96 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 97 | # ones. 98 | extensions = [ 99 | "sphinx.ext.autodoc", 100 | # 'sphinxcontrib.mockautodoc', # raises error: directive 'automodule' is already registered ... 
101 | # 'sphinxcontrib.fulltoc', # breaks pytorch-theme with unexpected kw argument 'titles_only' 102 | "sphinx.ext.doctest", 103 | "sphinx.ext.intersphinx", 104 | "sphinx.ext.todo", 105 | "sphinx.ext.coverage", 106 | "sphinx.ext.linkcode", 107 | "sphinx.ext.autosummary", 108 | "sphinx.ext.napoleon", 109 | "sphinx.ext.imgmath", 110 | "myst_parser", 111 | "sphinx.ext.autosectionlabel", 112 | "nbsphinx", 113 | "sphinx_autodoc_typehints", 114 | "sphinx_paramlinks", 115 | "sphinx.ext.githubpages", 116 | "pt_lightning_sphinx_theme.extensions.lightning", 117 | ] 118 | 119 | # Add any paths that contain templates here, relative to this directory. 120 | templates_path = ["_templates"] 121 | 122 | myst_update_mathjax = False 123 | 124 | # https://berkeley-stat159-f17.github.io/stat159-f17/lectures/14-sphinx..html#conf.py-(cont.) 125 | # https://stackoverflow.com/questions/38526888/embed-ipython-notebook-in-sphinx-document 126 | # I execute the notebooks manually in advance. If notebooks test the code, 127 | # they should be run at build time. 128 | nbsphinx_execute = "never" 129 | nbsphinx_allow_errors = True 130 | nbsphinx_requirejs_path = "" 131 | 132 | # The suffix(es) of source filenames. 133 | # You can specify multiple suffix as a list of string: 134 | # 135 | # source_suffix = ['.rst', '.md'] 136 | # source_suffix = ['.rst', '.md', '.ipynb'] 137 | source_suffix = { 138 | ".rst": "restructuredtext", 139 | ".txt": "markdown", 140 | ".md": "markdown", 141 | ".ipynb": "nbsphinx", 142 | } 143 | 144 | # The master toctree document. 145 | master_doc = "index" 146 | 147 | # The language for content autogenerated by Sphinx. Refer to documentation 148 | # for a list of supported languages. 149 | # 150 | # This is also used if you do content translation via gettext catalogs. 151 | # Usually you set "language" from the command line for these cases. 
152 | language = "en" 153 | 154 | # List of patterns, relative to source directory, that match files and 155 | # directories to ignore when looking for source files. 156 | # This pattern also affects html_static_path and html_extra_path. 157 | exclude_patterns = [ 158 | "PULL_REQUEST_TEMPLATE.md", 159 | ] 160 | 161 | # The name of the Pygments (syntax highlighting) style to use. 162 | pygments_style = None 163 | 164 | # -- Options for HTML output ------------------------------------------------- 165 | 166 | # The theme to use for HTML and HTML Help pages. See the documentation for 167 | # a list of builtin themes. 168 | # 169 | html_theme = "pt_lightning_sphinx_theme" 170 | html_theme_path = [pt_lightning_sphinx_theme.get_html_theme_path()] 171 | 172 | # Theme options are theme-specific and customize the look and feel of a theme 173 | # further. For a list of options available for each theme, see the 174 | # documentation. 175 | 176 | html_theme_options = { 177 | "pytorch_project": about.__homepage__, 178 | "canonical_url": about.__homepage__, 179 | "collapse_navigation": False, 180 | "display_version": True, 181 | "logo_only": False, 182 | } 183 | 184 | html_favicon = "_static/images/icon.svg" 185 | 186 | # Add any paths that contain custom static files (such as style sheets) here, 187 | # relative to this directory. They are copied after the builtin static files, 188 | # so a file named "default.css" will overwrite the builtin "default.css". 189 | html_static_path = ["_templates", "_static"] 190 | 191 | # Custom sidebar templates, must be a dictionary that maps document names 192 | # to template names. 193 | # 194 | # The default sidebars (for documents that don't match any pattern) are 195 | # defined by theme itself. Builtin themes are using these templates by 196 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 197 | # 'searchbox.html']``. 
198 | # 199 | # html_sidebars = {} 200 | 201 | # -- Options for HTMLHelp output --------------------------------------------- 202 | 203 | # Output file base name for HTML help builder. 204 | htmlhelp_basename = project + "-doc" 205 | 206 | # -- Options for LaTeX output ------------------------------------------------ 207 | 208 | latex_elements = { 209 | # The paper size ('letterpaper' or 'a4paper'). 210 | # 'papersize': 'letterpaper', 211 | # The font size ('10pt', '11pt' or '12pt'). 212 | # 'pointsize': '10pt', 213 | # Additional stuff for the LaTeX preamble. 214 | # 'preamble': '', 215 | # Latex figure (float) alignment 216 | "figure_align": "htbp", 217 | } 218 | 219 | # Grouping the document tree into LaTeX files. List of tuples 220 | # (source start file, target name, title, 221 | # author, documentclass [howto, manual, or own class]). 222 | latex_documents = [ 223 | (master_doc, project + ".tex", project + " Documentation", author, "manual"), 224 | ] 225 | 226 | # -- Options for manual page output ------------------------------------------ 227 | 228 | # One entry per manual page. List of tuples 229 | # (source start file, name, description, authors, manual section). 230 | man_pages = [(master_doc, project, project + " Documentation", [author], 1)] 231 | 232 | # -- Options for Texinfo output ---------------------------------------------- 233 | 234 | # Grouping the document tree into Texinfo files. List of tuples 235 | # (source start file, target name, title, author, 236 | # dir menu entry, description, category) 237 | texinfo_documents = [ 238 | ( 239 | master_doc, 240 | project, 241 | project + " Documentation", 242 | author, 243 | project, 244 | about.__docs__, 245 | "Miscellaneous", 246 | ), 247 | ] 248 | 249 | # -- Options for Epub output ------------------------------------------------- 250 | 251 | # Bibliographic Dublin Core info. 252 | epub_title = project 253 | 254 | # The unique identifier of the text. 
This can be a ISBN number 255 | # or the project homepage. 256 | # 257 | # epub_identifier = '' 258 | 259 | # A unique identification for the text. 260 | # 261 | # epub_uid = '' 262 | 263 | # A list of files that should not be packed into the epub file. 264 | epub_exclude_files = ["search.html"] 265 | 266 | # -- Extension configuration ------------------------------------------------- 267 | 268 | # -- Options for intersphinx extension --------------------------------------- 269 | 270 | # Example configuration for intersphinx: refer to the Python standard library. 271 | intersphinx_mapping = { 272 | "python": ("https://docs.python.org/3", None), 273 | "torch": ("https://pytorch.org/docs/stable/", None), 274 | "numpy": ("https://numpy.org/doc/stable/", None), 275 | } 276 | 277 | # -- Options for to-do extension ---------------------------------------------- 278 | 279 | # If true, `todo` and `todoList` produce output, else they produce nothing. 280 | todo_include_todos = True 281 | 282 | 283 | def setup(app): 284 | # this is for hiding doctest decoration, 285 | # see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/ 286 | app.add_js_file("copybutton.js") 287 | 288 | 289 | # Ignoring Third-party packages 290 | # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule 291 | def _package_list_from_file(file): 292 | list_pkgs = [] 293 | with open(file) as fp: 294 | lines = fp.readlines() 295 | for ln in lines: 296 | found = [ln.index(ch) for ch in list(",=<>#") if ch in ln] 297 | pkg = ln[: min(found)] if found else ln 298 | if pkg.rstrip(): 299 | list_pkgs.append(pkg.rstrip()) 300 | return list_pkgs 301 | 302 | 303 | # define mapping from PyPI names to python imports 304 | PACKAGE_MAPPING = { 305 | "PyYAML": "yaml", 306 | } 307 | MOCK_PACKAGES = [] 308 | if SPHINX_MOCK_REQUIREMENTS: 309 | # mock also base packages when we are on RTD since we don't install them there 310 | MOCK_PACKAGES += _package_list_from_file( 311 | 
os.path.join(_PATH_ROOT, "requirements.txt") 312 | ) 313 | MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES] 314 | 315 | autodoc_mock_imports = MOCK_PACKAGES 316 | 317 | 318 | # Resolve function 319 | # This function is used to populate the (source) links in the API 320 | def linkcode_resolve(domain, info): 321 | def find_source(): 322 | # try to find the file and line number, based on code from numpy: 323 | # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286 324 | obj = sys.modules[info["module"]] 325 | for part in info["fullname"].split("."): 326 | obj = getattr(obj, part) 327 | fname = inspect.getsourcefile(obj) 328 | # https://github.com/rtfd/readthedocs.org/issues/5735 329 | if any(s in fname for s in ("readthedocs", "rtfd", "checkouts")): 330 | # /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/ 331 | # devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176 332 | path_top = os.path.abspath(os.path.join("..", "..", "..")) 333 | fname = os.path.relpath(fname, start=path_top) 334 | else: 335 | # Local build, imitate master 336 | fname = "master/" + os.path.relpath(fname, start=os.path.abspath("..")) 337 | source, lineno = inspect.getsourcelines(obj) 338 | return fname, lineno, lineno + len(source) - 1 339 | 340 | if domain != "py" or not info["module"]: 341 | return None 342 | try: 343 | filename = "%s#L%d-L%d" % find_source() 344 | except Exception: 345 | filename = info["module"].replace(".", "/") + ".py" 346 | # import subprocess 347 | # tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE, 348 | # universal_newlines=True).communicate()[0][:-1] 349 | branch = filename.split("/")[0] 350 | # do mapping from latest tags to master 351 | branch = {"latest": "main", "stable": "main"}.get(branch, branch) 352 | filename = "/".join([branch] + filename.split("/")[1:]) 353 | return f"https://github.com/{github_user}/{github_repo}/blob/{filename}" 354 | 355 | 356 | 
autosummary_generate = True 357 | 358 | autodoc_member_order = "groupwise" 359 | autoclass_content = "both" 360 | # the options are fixed and will be soon in release, 361 | # see https://github.com/sphinx-doc/sphinx/issues/5459 362 | autodoc_default_options = { 363 | "members": None, 364 | "methods": None, 365 | # 'attributes': None, 366 | "special-members": "__call__", 367 | "exclude-members": "_abc_impl", 368 | "show-inheritance": True, 369 | "private-members": True, 370 | "noindex": True, 371 | } 372 | 373 | # Sphinx will add “permalinks” for each heading and description environment as paragraph signs that 374 | # become visible when the mouse hovers over them. 375 | # This value determines the text for the permalink; it defaults to "¶". Set it to None or the empty 376 | # string to disable permalinks. 377 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-html_add_permalinks 378 | # html_add_permalinks = "¶" 379 | 380 | # True to prefix each section label with the name of the document it is in, followed by a colon. 381 | # For example, index:Introduction for a section called Introduction that appears in document index.rst. 382 | # Useful for avoiding ambiguity when the same section heading appears in different documents. 383 | # http://www.sphinx-doc.org/en/master/usage/extensions/autosectionlabel.html 384 | autosectionlabel_prefix_document = True 385 | 386 | # only run doctests marked with a ".. doctest::" directive 387 | doctest_test_doctest_blocks = "" 388 | doctest_global_setup = """ 389 | """ 390 | coverage_skip_undoc_in_source = True 391 | 392 | linkcheck_ignore = [ 393 | # ignore the following URLs 394 | "https://github.com/gridai/lit-logger", 395 | ] 396 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. 
Lightning-AI-Sandbox documentation master file, created by 2 | sphinx-quickstart on Wed Mar 25 21:34:07 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Lightning-Models 7 | ================ 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :name: start 12 | :caption: Start here 13 | 14 | readme 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /examples/demo-tensorflow-keras.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | from litmodels import load_model, save_model 4 | 5 | if __name__ == "__main__": 6 | # Define the model 7 | model = keras.Sequential([ 8 | keras.layers.Dense(10, input_shape=(784,), name="dense_1"), 9 | keras.layers.Dense(10, name="dense_2"), 10 | ]) 11 | 12 | # Compile the model 13 | model.compile(optimizer="adam", loss="categorical_crossentropy") 14 | 15 | # Save the model 16 | save_model("lightning-ai/jirka/sample-tf-keras-model", model=model) 17 | 18 | # Load the model 19 | model_ = load_model("lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model") 20 | -------------------------------------------------------------------------------- /examples/demo-upload-download.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lightning.pytorch.demos.boring_classes import BoringModel 3 | 4 | import litmodels 5 | 6 | if __name__ == "__main__": 7 | # Define your model 8 | model = BoringModel() 9 | 10 | # Save the model's state dictionary 11 | torch.save(model.state_dict(), "./boring-checkpoint.pt") 12 | 13 | # Upload the model checkpoint 14 | litmodels.upload_model( 15 | "./boring-checkpoint.pt", 16 | "lightning-ai/jirka/lit-boring-model", 17 | ) 18 | 19 | # Download 
the model checkpoint 20 | model_path = litmodels.download_model("lightning-ai/jirka/lit-boring-model", download_dir="./my-models") 21 | print(f"Model downloaded to {model_path}") 22 | 23 | # Load the model checkpoint 24 | loaded_model = BoringModel() 25 | loaded_model.load_state_dict(torch.load("./boring-checkpoint.pt")) 26 | print(loaded_model) 27 | -------------------------------------------------------------------------------- /examples/resume-lightning-training.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example demonstrates how to resume training of a model using the `download_model` function. 3 | """ 4 | 5 | import os 6 | 7 | from lightning import Trainer 8 | from lightning.pytorch.demos.boring_classes import BoringModel 9 | 10 | from litmodels import download_model 11 | 12 | # Define the model name - this should be unique to your model 13 | # The format is //: 14 | MY_MODEL_NAME = "lightning-ai/jirka/lit-boring-callback:latest" 15 | 16 | 17 | if __name__ == "__main__": 18 | model_files = download_model(name=MY_MODEL_NAME, download_dir="my_models") 19 | model_path = os.path.join("my_models", model_files[0]) 20 | print(f"model: {model_path}") 21 | 22 | trainer = Trainer(max_epochs=4) 23 | trainer.fit( 24 | BoringModel(), 25 | ckpt_path=model_path, 26 | ) 27 | -------------------------------------------------------------------------------- /examples/train-model-and-simple-save.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example demonstrates how to train a model and upload it to the cloud using the `upload_model` function. 
3 | """ 4 | 5 | from lightning import Trainer 6 | from lightning.pytorch.demos.boring_classes import BoringModel 7 | 8 | from litmodels import upload_model 9 | 10 | # Define the model name - this should be unique to your model 11 | # The format is // 12 | MY_MODEL_NAME = "lightning-ai/jirka/lit-boring-simple" 13 | 14 | 15 | if __name__ == "__main__": 16 | trainer = Trainer(max_epochs=2) 17 | trainer.fit(BoringModel()) 18 | checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") 19 | print(f"best: {checkpoint_path}") 20 | upload_model(model=checkpoint_path, name=MY_MODEL_NAME) 21 | -------------------------------------------------------------------------------- /examples/train-model-with-lightning-callback.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a model with a Lightning callback that uploads the best model to the cloud after each epoch. 3 | """ 4 | 5 | from lightning import Trainer 6 | from lightning.pytorch.demos.boring_classes import BoringModel 7 | 8 | from litmodels.integrations import LightningModelCheckpoint 9 | 10 | # Define the model name - this should be unique to your model 11 | # The format is // 12 | MY_MODEL_NAME = "lightning-ai/jirka/lit-boring-callback" 13 | 14 | 15 | if __name__ == "__main__": 16 | trainer = Trainer( 17 | max_epochs=2, 18 | callbacks=LightningModelCheckpoint(model_registry=MY_MODEL_NAME), 19 | ) 20 | trainer.fit(BoringModel()) 21 | -------------------------------------------------------------------------------- /examples/train-model-with-lightning-logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Enhanced Logging with LightningLogger 3 | 4 | Integrate with [LitLogger](https://github.com/gridai/lit-logger) to automatically log your model checkpoints 5 | and training metrics to cloud storage. 
6 | Though the example utilizes PyTorch Lightning, this integration concept works across various model training frameworks. 7 | 8 | """ 9 | 10 | from lightning import Trainer 11 | from lightning.pytorch.demos.boring_classes import BoringModel 12 | from litlogger import LightningLogger 13 | 14 | 15 | class DemoModel(BoringModel): 16 | def training_step(self, batch, batch_idx): 17 | output = super().training_step(batch, batch_idx) 18 | self.log("train_loss", output["loss"]) 19 | return output 20 | 21 | 22 | if __name__ == "__main__": 23 | # configure the logger 24 | lit_logger = LightningLogger(log_model=True) 25 | 26 | # pass logger to the Trainer 27 | trainer = Trainer(max_epochs=5, logger=lit_logger) 28 | 29 | # train the model 30 | trainer.fit(model=DemoModel()) 31 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = "LICENSE" 3 | description-file = "README.md" 4 | 5 | [build-system] 6 | requires = [ 7 | "setuptools", 8 | "wheel", 9 | ] 10 | 11 | 12 | [tool.check-manifest] 13 | ignore = [ 14 | "*.yml", 15 | ".github", 16 | ".github/*" 17 | ] 18 | 19 | 20 | [tool.pytest.ini_options] 21 | norecursedirs = [ 22 | ".git", 23 | ".github", 24 | "dist", 25 | "build", 26 | "docs", 27 | ] 28 | addopts = [ 29 | "--strict-markers", 30 | "--doctest-modules", 31 | "--color=yes", 32 | "--disable-pytest-warnings", 33 | ] 34 | markers = [ 35 | "cloud:Run the cloud tests for example", 36 | ] 37 | filterwarnings = [ 38 | "error::FutureWarning", 39 | ] 40 | xfail_strict = true 41 | junit_duration_report = "call" 42 | 43 | [tool.coverage.report] 44 | exclude_lines = [ 45 | "pragma: no cover", 46 | "pass", 47 | ] 48 | 49 | [tool.codespell] 50 | #skip = '*.py' 51 | quiet-level = 3 52 | # comma separated list of words; waiting for: 53 | # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603 54 | # 
also adding links until they ignored by its: nature 55 | # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960 56 | #ignore-words-list = "" 57 | 58 | 59 | [tool.docformatter] 60 | recursive = true 61 | wrap-summaries = 120 62 | wrap-descriptions = 120 63 | blank = true 64 | 65 | 66 | [tool.mypy] 67 | files = [ 68 | "src", 69 | ] 70 | install_types = true 71 | non_interactive = true 72 | disallow_untyped_defs = true 73 | ignore_missing_imports = true 74 | show_error_codes = true 75 | warn_redundant_casts = true 76 | warn_unused_configs = true 77 | warn_unused_ignores = true 78 | allow_redefinition = true 79 | # disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ 80 | disable_error_code = "attr-defined" 81 | # style choices 82 | warn_no_return = false 83 | 84 | 85 | [tool.ruff] 86 | line-length = 120 87 | target-version = 'py38' 88 | exclude = [ 89 | "build", 90 | "dist", 91 | "docs" 92 | ] 93 | 94 | [tool.ruff.format] 95 | preview = true 96 | 97 | 98 | [tool.ruff.lint] 99 | select = [ 100 | "E", "W", # see: https://pypi.org/project/pycodestyle 101 | "F", # see: https://pypi.org/project/pyflakes 102 | "D", # see: https://pypi.org/project/pydocstyle 103 | "N", # see: https://pypi.org/project/pep8-naming 104 | "RUF018", # see: https://docs.astral.sh/ruff/rules/assignment-in-assert 105 | "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up 106 | "I", # implementation for isort 107 | ] 108 | extend-select = [ 109 | "C4", # see: https://pypi.org/project/flake8-comprehensions 110 | "PT", # see: https://pypi.org/project/flake8-pytest-style 111 | "RET", # see: https://pypi.org/project/flake8-return 112 | "SIM", # see: https://pypi.org/project/flake8-simplify 113 | ] 114 | ignore = [ 115 | "E731", # Do not assign a lambda expression, use a def 116 | "D100", # Missing docstring in public module 117 | ] 118 | 119 | [tool.ruff.lint.per-file-ignores] 120 | "setup.py" = ["D100", "SIM115"] 121 | 
"__about__.py" = ["D100"] 122 | "__init__.py" = ["D100", "E402"] 123 | "examples/**" = ["D"] # todo 124 | "tests/**" = ["D"] 125 | 126 | [tool.ruff.lint.pydocstyle] 127 | # Use Google-style docstrings. 128 | convention = "google" 129 | 130 | #[tool.ruff.pycodestyle] 131 | #ignore-overlong-task-comments = true 132 | 133 | [tool.ruff.lint.mccabe] 134 | # Unlike Flake8, default to a complexity level of 10. 135 | max-complexity = 10 136 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lightning-sdk >=0.2.11 2 | lightning-utilities 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import glob 3 | import os 4 | from importlib.util import module_from_spec, spec_from_file_location 5 | from pathlib import Path 6 | 7 | from pkg_resources import parse_requirements 8 | from setuptools import find_packages, setup 9 | 10 | _PATH_ROOT = os.path.dirname(__file__) 11 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src") 12 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements") 13 | 14 | 15 | def _load_py_module(fname, pkg="litmodels"): 16 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname)) 17 | py = module_from_spec(spec) 18 | spec.loader.exec_module(py) 19 | return py 20 | 21 | 22 | def _load_requirements(path_dir: str = _PATH_ROOT, file_name: str = "requirements.txt") -> list: 23 | reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines()) 24 | return list(map(str, reqs)) 25 | 26 | 27 | about = _load_py_module("__about__.py") 28 | with open(os.path.join(_PATH_ROOT, "README.md"), encoding="utf-8") as fopen: 29 | readme = fopen.read() 30 | 31 | 32 | def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple 
= ("devel.txt",)) -> dict: 33 | # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras 34 | # Define package extras. These are only installed if you specify them. 35 | # From remote, use like `pip install pytorch-lightning[dev, docs]` 36 | # From local copy of repo, use like `pip install ".[dev, docs]"` 37 | req_files = [Path(p) for p in glob.glob(os.path.join(requirements_dir, "*.txt"))] 38 | extras = { 39 | p.stem: _load_requirements(file_name=p.name, path_dir=str(p.parent)) 40 | for p in req_files 41 | # ignore some development specific requirements 42 | if p.name not in skip_files and not p.name.startswith("_") 43 | } 44 | # todo: eventually add some custom aggregations such as `develop` 45 | extras = {name: sorted(set(reqs)) for name, reqs in extras.items()} 46 | print("The extras are: ", extras) 47 | return extras 48 | 49 | 50 | # https://packaging.python.org/discussions/install-requires-vs-requirements / 51 | # keep the meta-data here for simplicity in reading this file... it's not obvious 52 | # what happens and to non-engineers they won't know to look in init ... 
53 | # the goal of the project is simplicity for researchers, don't want to add too much 54 | # engineer specific practices 55 | setup( 56 | name="litmodels", 57 | version=about.__version__, 58 | description=about.__docs__, 59 | author=about.__author__, 60 | author_email=about.__author_email__, 61 | url=about.__homepage__, 62 | download_url="https://github.com/Lightning-AI/litModels", 63 | license=about.__license__, 64 | packages=find_packages(where="src"), 65 | package_dir={"": "src"}, 66 | long_description=readme, 67 | long_description_content_type="text/markdown", 68 | include_package_data=True, 69 | zip_safe=False, 70 | keywords=["deep learning", "pytorch", "AI"], 71 | python_requires=">=3.8", 72 | setup_requires=["wheel"], 73 | install_requires=_load_requirements(), 74 | extras_require=_prepare_extras(), 75 | project_urls={ 76 | "Bug Tracker": "https://github.com/Lightning-AI/litModels/issues", 77 | "Documentation": "https://lightning-ai.github.io/litModels/", 78 | "Source Code": "https://github.com/Lightning-AI/litModels", 79 | }, 80 | classifiers=[ 81 | "Environment :: Console", 82 | "Natural Language :: English", 83 | # How mature is this project? Common values are 84 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable 85 | "Development Status :: 3 - Alpha", 86 | # Indicate who your project is intended for 87 | "Intended Audience :: Developers", 88 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 89 | "Topic :: Scientific/Engineering :: Information Analysis", 90 | # Pick your license as you wish 91 | "License :: OSI Approved :: Apache Software License", 92 | "Operating System :: OS Independent", 93 | # Specify the Python versions you support here. In particular, ensure 94 | # that you indicate whether you support Python 2, Python 3 or both. 
95 | "Programming Language :: Python :: 3", 96 | "Programming Language :: Python :: 3.8", 97 | "Programming Language :: Python :: 3.9", 98 | "Programming Language :: Python :: 3.10", 99 | "Programming Language :: Python :: 3.11", 100 | "Programming Language :: Python :: 3.12", 101 | "Programming Language :: Python :: 3.13", 102 | ], 103 | ) 104 | -------------------------------------------------------------------------------- /src/litmodels/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.8" 2 | __author__ = "Lightning-AI et al." 3 | __author_email__ = "community@lightning.ai" 4 | __license__ = "Apache-2.0" 5 | __copyright__ = f"Copyright (c) 2024-2025, {__author__}." 6 | __homepage__ = "https://github.com/Lightning-AI/litModels" 7 | __docs__ = "Lightning AI Model hub." 8 | 9 | __all__ = [ 10 | "__author__", 11 | "__author_email__", 12 | "__copyright__", 13 | "__docs__", 14 | "__homepage__", 15 | "__license__", 16 | "__version__", 17 | ] 18 | -------------------------------------------------------------------------------- /src/litmodels/__init__.py: -------------------------------------------------------------------------------- 1 | """Root package info.""" 2 | 3 | import os 4 | 5 | from litmodels.__about__ import * # noqa: F401, F403 6 | 7 | _PACKAGE_ROOT = os.path.dirname(__file__) 8 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) 9 | 10 | from litmodels.io import download_model, load_model, save_model, upload_model, upload_model_files # noqa: F401 11 | 12 | __all__ = ["download_model", "upload_model", "load_model", "save_model"] 13 | -------------------------------------------------------------------------------- /src/litmodels/demos/__init__.py: -------------------------------------------------------------------------------- 1 | """Define a demos model for examples and testing purposes.""" 2 | 3 | from lightning_utilities import module_available 4 | 5 | __all__ = [] 6 | 7 | if 
module_available("lightning"): 8 | from lightning.pytorch.demos.boring_classes import BoringModel, DemoModel 9 | 10 | __all__ += ["BoringModel", "DemoModel"] 11 | elif module_available("pytorch_lightning"): 12 | from pytorch_lightning.demos.boring_classes import BoringModel, DemoModel 13 | 14 | __all__ += ["BoringModel", "DemoModel"] 15 | -------------------------------------------------------------------------------- /src/litmodels/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | """Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others.""" 2 | 3 | from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE 4 | 5 | __all__ = [] 6 | 7 | if _LIGHTNING_AVAILABLE: 8 | from litmodels.integrations.checkpoints import LightningModelCheckpoint 9 | 10 | __all__ += ["LightningModelCheckpoint"] 11 | 12 | if _PYTORCHLIGHTNING_AVAILABLE: 13 | from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint 14 | 15 | __all__ += ["PytorchLightningModelCheckpoint"] 16 | -------------------------------------------------------------------------------- /src/litmodels/integrations/checkpoints.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os.path 3 | import queue 4 | import threading 5 | from abc import ABC 6 | from datetime import datetime 7 | from functools import lru_cache 8 | from pathlib import Path 9 | from typing import TYPE_CHECKING, Any, Optional, Union 10 | 11 | from lightning_sdk.lightning_cloud.login import Auth 12 | from lightning_sdk.utils.resolve import _resolve_teamspace 13 | from lightning_utilities import StrEnum 14 | from lightning_utilities.core.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn 15 | 16 | from litmodels import upload_model 17 | from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE 18 | from 
litmodels.io.cloud import _list_available_teamspaces, delete_model_version 19 | 20 | if _LIGHTNING_AVAILABLE: 21 | from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint 22 | 23 | 24 | if _PYTORCHLIGHTNING_AVAILABLE: 25 | from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint 26 | 27 | 28 | if TYPE_CHECKING: 29 | if _LIGHTNING_AVAILABLE: 30 | import lightning.pytorch as pl 31 | if _PYTORCHLIGHTNING_AVAILABLE: 32 | import pytorch_lightning as pl 33 | 34 | 35 | # Create a singleton upload manager 36 | @lru_cache(maxsize=None) 37 | def get_model_manager() -> "ModelManager": 38 | """Get or create the singleton upload manager.""" 39 | return ModelManager() 40 | 41 | 42 | # enumerate the possible actions 43 | class Action(StrEnum): 44 | """Enumeration of possible actions for the ModelManager.""" 45 | 46 | UPLOAD = "upload" 47 | REMOVE = "remove" 48 | 49 | 50 | class RemoveType(StrEnum): 51 | """Enumeration of possible remove types for the ModelManager.""" 52 | 53 | LOCAL = "local" 54 | CLOUD = "cloud" 55 | 56 | 57 | class ModelManager: 58 | """Manages uploads and removals with a single queue but separate counters.""" 59 | 60 | task_queue: queue.Queue 61 | 62 | def __init__(self) -> None: 63 | """Initialize the ModelManager with a task queue and counters.""" 64 | self.task_queue = queue.Queue() 65 | self.upload_count = 0 66 | self.remove_count = 0 67 | self._worker = threading.Thread(target=self._worker_loop, daemon=True) 68 | self._worker.start() 69 | 70 | def __getstate__(self) -> dict: 71 | """Get the state of the ModelManager for pickling.""" 72 | state = self.__dict__.copy() 73 | del state["task_queue"] 74 | del state["_worker"] 75 | return state 76 | 77 | def __setstate__(self, state: dict) -> None: 78 | """Set the state of the ModelManager after unpickling.""" 79 | self.__dict__.update(state) 80 | import queue 81 | import threading 82 | 83 | self.task_queue = queue.Queue() 84 | self._worker = 
threading.Thread(target=self._worker_loop, daemon=True) 85 | self._worker.start() 86 | 87 | def _worker_loop(self) -> None: 88 | while True: 89 | task = self.task_queue.get() 90 | if task is None: 91 | self.task_queue.task_done() 92 | break 93 | action, detail = task 94 | if action == Action.UPLOAD: 95 | registry_name, filepath, metadata = detail 96 | try: 97 | upload_model(name=registry_name, model=filepath, metadata=metadata) 98 | rank_zero_debug(f"Finished uploading: {filepath}") 99 | except Exception as ex: 100 | rank_zero_warn(f"Upload failed {filepath}: {ex}") 101 | finally: 102 | self.upload_count -= 1 103 | elif action == Action.REMOVE: 104 | filepath, trainer, registry_name = detail 105 | try: 106 | if registry_name: 107 | rank_zero_debug(f"Removing from cloud: {filepath}") 108 | # Remove from the cloud 109 | version = os.path.splitext(os.path.basename(filepath))[0] 110 | delete_model_version(name=registry_name, version=version) 111 | if trainer: 112 | rank_zero_debug(f"Removed local file: {filepath}") 113 | trainer.strategy.remove_checkpoint(filepath) 114 | except Exception as ex: 115 | rank_zero_warn(f"Removal failed {filepath}: {ex}") 116 | finally: 117 | self.remove_count -= 1 118 | else: 119 | rank_zero_warn(f"Unknown task: {task}") 120 | self.task_queue.task_done() 121 | 122 | def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata: Optional[dict] = None) -> None: 123 | """Queue an upload task.""" 124 | self.upload_count += 1 125 | self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata))) 126 | rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})") 127 | 128 | def queue_remove( 129 | self, filepath: Union[str, Path], trainer: Optional["pl.Trainer"] = None, registry_name: Optional[str] = None 130 | ) -> None: 131 | """Queue a removal task.""" 132 | self.remove_count += 1 133 | self.task_queue.put((Action.REMOVE, (filepath, trainer, registry_name))) 134 | rank_zero_debug(f"Queued 
removal: {filepath} (pending removals: {self.remove_count})") 135 | 136 | def shutdown(self) -> None: 137 | """Shut down the manager and wait for all tasks to complete.""" 138 | self.task_queue.put(None) 139 | self.task_queue.join() 140 | rank_zero_debug("Manager shut down.") 141 | 142 | 143 | # Base class to be inherited 144 | class LitModelCheckpointMixin(ABC): 145 | """Mixin class for LitModel checkpoint functionality.""" 146 | 147 | _datetime_stamp: str 148 | model_registry: Optional[str] = None 149 | _model_manager: ModelManager 150 | 151 | def __init__( 152 | self, model_registry: Optional[str], keep_all_uploaded: bool = False, clear_all_local: bool = False 153 | ) -> None: 154 | """Initialize with model name. 155 | 156 | Args: 157 | model_registry: Name of the model to upload in format 'organization/teamspace/modelname'. 158 | keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so. 159 | clear_all_local: Whether to clear local models after uploading to the cloud. 160 | """ 161 | if not model_registry: 162 | rank_zero_warn( 163 | "The model is not defined so we will continue with LightningModule names and timestamp of now" 164 | ) 165 | self._datetime_stamp = datetime.now().strftime("%Y%m%d-%H%M") 166 | # remove any / from beginning and end of the name 167 | self.model_registry = model_registry.strip("/") if model_registry else None 168 | self._keep_all_uploaded = keep_all_uploaded 169 | self._clear_all_local = clear_all_local 170 | 171 | try: # authenticate before anything else starts 172 | Auth().authenticate() 173 | except Exception: 174 | raise ConnectionError("Unable to authenticate with Lightning Cloud. 
Check your credentials.") 175 | 176 | self._model_manager = ModelManager() 177 | 178 | @rank_zero_only 179 | def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metadata: Optional[dict] = None) -> None: 180 | if not self.model_registry: 181 | raise RuntimeError( 182 | "Model name is not specified neither updated by `setup` method via Trainer." 183 | " Please set the model name before uploading or ensure that `setup` method is called." 184 | ) 185 | model_registry = self.model_registry 186 | if os.path.isfile(filepath): 187 | # parse the file name as version 188 | version, _ = os.path.splitext(os.path.basename(filepath)) 189 | model_registry += f":{version}" 190 | if not metadata: 191 | metadata = {} 192 | # Add the integration name to the metadata 193 | mro = inspect.getmro(type(self)) 194 | abc_index = mro.index(LitModelCheckpointMixin) 195 | ckpt_class = mro[abc_index - 1] 196 | metadata.update({"litModels.integration": ckpt_class.__name__}) 197 | # Add to queue instead of uploading directly 198 | get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata) 199 | if self._clear_all_local: 200 | get_model_manager().queue_remove(filepath=filepath, trainer=trainer) 201 | 202 | @rank_zero_only 203 | def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None: 204 | """Remove the local version of the model if requested.""" 205 | get_model_manager().queue_remove( 206 | filepath=filepath, 207 | # skip the local removal we put it in the queue right after the upload 208 | trainer=None if self._clear_all_local else trainer, 209 | # skip the cloud removal if we keep all uploaded models 210 | registry_name=None if self._keep_all_uploaded else self.model_registry, 211 | ) 212 | 213 | def default_model_name(self, pl_model: "pl.LightningModule") -> str: 214 | """Generate a default model name based on the class name and timestamp.""" 215 | return pl_model.__class__.__name__ + 
f"_{self._datetime_stamp}" 216 | 217 | def _update_model_name(self, pl_model: "pl.LightningModule") -> None: 218 | """Update the model name if not already set.""" 219 | count_slashes_in_name = self.model_registry.count("/") if self.model_registry else 0 220 | default_model_name = self.default_model_name(pl_model) 221 | if count_slashes_in_name > 2: 222 | raise ValueError( 223 | f"Invalid model name: '{self.model_registry}'. It should not contain more than two '/' character." 224 | ) 225 | if count_slashes_in_name == 2: 226 | # user has defined the model name in the format 'organization/teamspace/modelname' 227 | return 228 | if count_slashes_in_name == 1: 229 | # user had defined only the teamspace name 230 | self.model_registry = f"{self.model_registry}/{default_model_name}" 231 | elif count_slashes_in_name == 0: 232 | if not self.model_registry: 233 | self.model_registry = default_model_name 234 | teamspace = _resolve_teamspace(None, None, None) 235 | if teamspace: 236 | # case you use default model name and teamspace determined from env. variables aka running in studio 237 | self.model_registry = f"{teamspace.owner.name}/{teamspace.name}/{self.model_registry}" 238 | else: # try to load default users teamspace 239 | ts_names = list(_list_available_teamspaces().keys()) 240 | if len(ts_names) == 1: 241 | self.model_registry = f"{ts_names[0]}/{self.model_registry}" 242 | else: 243 | options = "\n\t".join(ts_names) 244 | raise RuntimeError( 245 | f"Teamspace is not defined and there are multiple teamspaces available:\n{options}" 246 | ) 247 | else: 248 | raise RuntimeError(f"Invalid model name: '{self.model_registry}'") 249 | 250 | 251 | # Create specific implementations 252 | if _LIGHTNING_AVAILABLE: 253 | 254 | class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint): 255 | """Lightning ModelCheckpoint with LitModel support. 256 | 257 | Args: 258 | model_registry: Name of the model to upload in format 'organization/teamspace/modelname'. 
259 | keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so. 260 | clear_all_local: Whether to clear local models after uploading to the cloud. 261 | *args: Additional arguments to pass to the parent class. 262 | **kwargs: Additional keyword arguments to pass to the parent class. 263 | """ 264 | 265 | def __init__( 266 | self, 267 | *args: Any, 268 | model_name: Optional[str] = None, 269 | model_registry: Optional[str] = None, 270 | keep_all_uploaded: bool = False, 271 | clear_all_local: bool = False, 272 | **kwargs: Any, 273 | ) -> None: 274 | """Initialize the checkpoint with model name and other parameters.""" 275 | _LightningModelCheckpoint.__init__(self, *args, **kwargs) 276 | if model_name is not None: 277 | rank_zero_warn( 278 | "The 'model_name' argument is deprecated and will be removed in a future version." 279 | " Please use 'model_registry' instead." 280 | ) 281 | LitModelCheckpointMixin.__init__( 282 | self, 283 | model_registry=model_registry or model_name, 284 | keep_all_uploaded=keep_all_uploaded, 285 | clear_all_local=clear_all_local, 286 | ) 287 | 288 | def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: 289 | """Setup the checkpoint callback.""" 290 | _LightningModelCheckpoint.setup(self, trainer, pl_module, stage) 291 | self._update_model_name(pl_module) 292 | 293 | def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 294 | """Extend the save checkpoint method to upload the model.""" 295 | _LightningModelCheckpoint._save_checkpoint(self, trainer, filepath) 296 | if trainer.is_global_zero: # Only upload from the main process 297 | self._upload_model(trainer=trainer, filepath=filepath) 298 | 299 | def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 300 | """Extend the on_fit_end method to ensure all uploads are completed.""" 301 | _LightningModelCheckpoint.on_fit_end(self, trainer, pl_module) 302 | # Wait 
for all uploads to finish 303 | get_model_manager().shutdown() 304 | 305 | def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 306 | """Extend the remove checkpoint method to remove the model from the registry.""" 307 | if trainer.is_global_zero: # Only remove from the main process 308 | self._remove_model(trainer=trainer, filepath=filepath) 309 | 310 | 311 | if _PYTORCHLIGHTNING_AVAILABLE: 312 | 313 | class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint): 314 | """PyTorch Lightning ModelCheckpoint with LitModel support. 315 | 316 | Args: 317 | model_registry: Name of the model to upload in format 'organization/teamspace/modelname'. 318 | keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so. 319 | clear_all_local: Whether to clear local models after uploading to the cloud. 320 | args: Additional arguments to pass to the parent class. 321 | kwargs: Additional keyword arguments to pass to the parent class. 322 | """ 323 | 324 | def __init__( 325 | self, 326 | *args: Any, 327 | model_name: Optional[str] = None, 328 | model_registry: Optional[str] = None, 329 | keep_all_uploaded: bool = False, 330 | clear_all_local: bool = False, 331 | **kwargs: Any, 332 | ) -> None: 333 | """Initialize the checkpoint with model name and other parameters.""" 334 | _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs) 335 | if model_name is not None: 336 | rank_zero_warn( 337 | "The 'model_name' argument is deprecated and will be removed in a future version." 338 | " Please use 'model_registry' instead." 
339 | ) 340 | LitModelCheckpointMixin.__init__( 341 | self, 342 | model_registry=model_registry or model_name, 343 | keep_all_uploaded=keep_all_uploaded, 344 | clear_all_local=clear_all_local, 345 | ) 346 | 347 | def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: 348 | """Setup the checkpoint callback.""" 349 | _PytorchLightningModelCheckpoint.setup(self, trainer, pl_module, stage) 350 | self._update_model_name(pl_module) 351 | 352 | def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 353 | """Extend the save checkpoint method to upload the model.""" 354 | _PytorchLightningModelCheckpoint._save_checkpoint(self, trainer, filepath) 355 | if trainer.is_global_zero: # Only upload from the main process 356 | self._upload_model(trainer=trainer, filepath=filepath) 357 | 358 | def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 359 | """Extend the on_fit_end method to ensure all uploads are completed.""" 360 | _PytorchLightningModelCheckpoint.on_fit_end(self, trainer, pl_module) 361 | # Wait for all uploads to finish 362 | get_model_manager().shutdown() 363 | 364 | def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 365 | """Extend the remove checkpoint method to remove the model from the registry.""" 366 | if trainer.is_global_zero: # Only remove from the main process 367 | self._remove_model(trainer=trainer, filepath=filepath) 368 | -------------------------------------------------------------------------------- /src/litmodels/integrations/duplicate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | from lightning_utilities import module_available 8 | 9 | from litmodels.io import upload_model_files 10 | 11 | if module_available("huggingface_hub"): 12 | from huggingface_hub import snapshot_download 
13 | else: 14 | snapshot_download = None 15 | 16 | 17 | def duplicate_hf_model( 18 | hf_model: str, 19 | lit_model: Optional[str] = None, 20 | local_workdir: Optional[str] = None, 21 | verbose: int = 1, 22 | metadata: Optional[dict] = None, 23 | ) -> str: 24 | """Downloads the model from Hugging Face and uploads it to Lightning Cloud. 25 | 26 | Args: 27 | hf_model: The name of the Hugging Face model to duplicate. 28 | lit_model: The name of the Lightning Cloud model to create. 29 | local_workdir: 30 | The local working directory to use for the duplication process. If not set a temp folder will be created. 31 | verbose: Shot a progress bar for the upload. 32 | metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used. 33 | 34 | Returns: 35 | The name of the duplicated model in Lightning Cloud. 36 | """ 37 | if not snapshot_download: 38 | raise ModuleNotFoundError( 39 | "Hugging Face Hub is not installed. Please install it with `pip install huggingface_hub`." 
40 | ) 41 | 42 | if not local_workdir: 43 | local_workdir = tempfile.mkdtemp() 44 | local_workdir = Path(local_workdir) 45 | model_name = hf_model.replace("/", "_") 46 | 47 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 48 | # Download the model from Hugging Face 49 | snapshot_download( 50 | repo_id=hf_model, 51 | revision="main", # Branch/tag/commit 52 | repo_type="model", # Options: "dataset", "model", "space" 53 | local_dir=local_workdir / model_name, # Specify to save in custom location, default is cache 54 | local_dir_use_symlinks=True, # Use symlinks to save disk space 55 | ignore_patterns=[".cache*"], # Exclude certain files if needed 56 | max_workers=os.cpu_count(), # Number of parallel downloads 57 | ) 58 | # prune cache in the downloaded model 59 | for path in local_workdir.rglob(".cache*"): 60 | shutil.rmtree(path) 61 | 62 | # Upload the model to Lightning Cloud 63 | if not lit_model: 64 | lit_model = model_name 65 | if not metadata: 66 | metadata = {} 67 | metadata.update({"litModels.integration": "duplicate_hf_model", "hf_model": hf_model}) 68 | model = upload_model_files(name=lit_model, path=local_workdir / model_name, verbose=verbose, metadata=metadata) 69 | return model.name 70 | -------------------------------------------------------------------------------- /src/litmodels/integrations/imports.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | from lightning_utilities import compare_version, module_available 4 | 5 | _LIGHTNING_AVAILABLE = module_available("lightning") 6 | _LIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("lightning", operator.ge, "2.5.1") 7 | _PYTORCHLIGHTNING_AVAILABLE = module_available("pytorch_lightning") 8 | _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("pytorch_lightning", operator.ge, "2.5.1") 9 | -------------------------------------------------------------------------------- /src/litmodels/integrations/mixins.py: 
# --------------------------------------------------------------------------------
import inspect
import json
import tempfile
import warnings
from abc import ABC
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

from lightning_utilities.core.rank_zero import rank_zero_warn

from litmodels.io.cloud import download_model_files, upload_model_files
from litmodels.io.utils import dump_pickle, load_pickle

if TYPE_CHECKING:
    import torch


def _as_file_list(files: Union[str, List[str]]) -> List[str]:
    """Normalize the result of ``download_model_files`` (a single path or a list of paths) to a list."""
    return [files] if isinstance(files, str) else list(files)


class ModelRegistryMixin(ABC):
    """Mixin for model registry integration."""

    def upload_model(
        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> None:
        """Push the model to the registry.

        Args:
            name: The name of the model. If not use the class name.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
        """

    @classmethod
    def download_model(
        cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> object:
        """Pull the model from the registry.

        Args:
            name: The name of the model.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
        """

    def _setup(
        self, name: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> Tuple[str, str, Union[str, Path]]:
        """Parse and validate the model name and temporary folder.

        Returns:
            The registry name, the bare model name (last '/' segment) and the temporary folder.

        Raises:
            ValueError: If ``name`` already contains a ':' version suffix.
        """
        if name is None:
            name = model_name = self.__class__.__name__
        elif ":" in name:
            raise ValueError(f"Invalid model name: '{name}'. It should not contain ':' associated with version.")
        else:
            model_name = name.split("/")[-1]
        if temp_folder is None:
            temp_folder = tempfile.mkdtemp()
        return name, model_name, temp_folder

    def _upload_model_files(
        self, name: str, path: Union[str, Path, List[Union[str, Path]]], metadata: Optional[dict] = None
    ) -> None:
        """Upload the model files to the registry, tagging them with the concrete mixin class name."""
        # work on a copy so the caller's dict is never mutated
        metadata = dict(metadata) if metadata else {}
        # Add the integration name to the metadata: the class right before this mixin in the MRO
        mro = inspect.getmro(type(self))
        abc_index = mro.index(ModelRegistryMixin)
        mixin_class = mro[abc_index - 1]
        metadata.update({"litModels.integration": mixin_class.__name__})
        upload_model_files(name=name, path=path, metadata=metadata)


class PickleRegistryMixin(ModelRegistryMixin):
    """Mixin for pickle registry integration."""

    def upload_model(
        self,
        name: Optional[str] = None,
        version: Optional[str] = None,
        temp_folder: Union[str, Path, None] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Push the model to the registry.

        Args:
            name: The name of the model. If not use the class name.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
            metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
        """
        name, model_name, temp_folder = self._setup(name, temp_folder)
        pickle_path = Path(temp_folder) / f"{model_name}.pkl"
        dump_pickle(model=self, path=pickle_path)
        if version:
            name = f"{name}:{version}"
        self._upload_model_files(name=name, path=pickle_path, metadata=metadata)

    @classmethod
    def download_model(
        cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
    ) -> object:
        """Pull the model from the registry.

        Args:
            name: The name of the model.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.

        Raises:
            RuntimeError: If there is not exactly one pickle file, or it does not hold an instance of ``cls``.
        """
        if temp_folder is None:
            temp_folder = tempfile.mkdtemp()
        model_registry = f"{name}:{version}" if version else name
        # a single downloaded file may come back as a plain string; iterating it would yield characters
        files = _as_file_list(download_model_files(name=model_registry, download_dir=temp_folder))
        pkl_files = [f for f in files if f.endswith(".pkl")]
        if not pkl_files:
            raise RuntimeError(f"No pickle file found for model: {model_registry} with {files}")
        if len(pkl_files) > 1:
            raise RuntimeError(f"Multiple pickle files found for model: {model_registry} with {pkl_files}")
        pkl_path = Path(temp_folder) / pkl_files[0]
        # NOTE: unpickling can execute arbitrary code - only download from registries you trust
        obj = load_pickle(path=pkl_path)
        if not isinstance(obj, cls):
            raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}")
        return obj


class PyTorchRegistryMixin(ModelRegistryMixin):
    """Mixin for PyTorch model registry integration."""

    def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
        """Create a new instance and record the constructor arguments it is being created with.

        The bound ``__init__`` arguments (with defaults applied, ``**kwargs`` flattened) are stored on the instance,
        so that ``upload_model`` can serialize them and ``download_model`` can re-create the model.
        """
        instance = super().__new__(cls)

        # Get __init__ signature excluding 'self'
        init_sig = inspect.signature(cls.__init__)
        params = list(init_sig.parameters.values())[1:]  # Skip self
        # Create temporary signature for binding
        temp_sig = init_sig.replace(parameters=params)

        try:
            bound_args = temp_sig.bind(*args, **kwargs)
        except TypeError:
            # `copy.deepcopy` / unpickling call `cls.__new__(cls)` without constructor arguments; the recorded
            # arguments are then restored from the copied instance state. A genuinely wrong call still fails,
            # with the same TypeError, when Python runs `__init__` right after.
            return instance
        bound_args.apply_defaults()

        init_kwargs = {}
        for param_name, value in bound_args.arguments.items():
            kind = temp_sig.parameters[param_name].kind
            if kind is inspect.Parameter.VAR_KEYWORD:
                # flatten `**kwargs` so that `cls(**init_kwargs)` passes them back the same way
                init_kwargs.update(value)
            elif kind is inspect.Parameter.VAR_POSITIONAL and not value:
                continue  # nothing was passed via `*args`
            else:
                init_kwargs[param_name] = value

        # Store unified kwargs
        instance.__init_kwargs = init_kwargs
        return instance

    def upload_model(
        self,
        name: Optional[str] = None,
        version: Optional[str] = None,
        temp_folder: Union[str, Path, None] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Push the model to the registry.

        Args:
            name: The name of the model. If not use the class name.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
            metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

        Raises:
            TypeError: If the instance is not a ``torch.nn.Module``.
            RuntimeError: If the recorded constructor arguments are not JSON serializable.
        """
        import torch

        # The state dict can only be taken from a PyTorch module
        if not isinstance(self, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")

        name, model_name, temp_folder = self._setup(name, temp_folder)

        # arguments recorded by `__new__`; Python mangles `self.__init_kwargs` to this attribute name
        init_kwargs = getattr(self, "_PyTorchRegistryMixin__init_kwargs", None)
        init_kwargs_path = None
        if init_kwargs:
            try:
                # Save the model arguments to a JSON file
                init_kwargs_path = Path(temp_folder) / f"{model_name}__init_kwargs.json"
                with open(init_kwargs_path, "w") as fp:
                    json.dump(init_kwargs, fp)
            except Exception as ex:
                raise RuntimeError(
                    f"Failed to save model arguments: {ex}."
                    " Ensure the model's arguments are JSON serializable or use `PickleRegistryMixin`."
                ) from ex
        elif init_kwargs is None:
            rank_zero_warn(
                "The child class is missing `__init_kwargs`."
                " Ensure `PyTorchRegistryMixin` is first in the inheritance order"
                " or call `PyTorchRegistryMixin.__init__` explicitly in the child class."
            )

        torch_state_dict_path = Path(temp_folder) / f"{model_name}.pth"
        torch.save(self.state_dict(), torch_state_dict_path)
        model_registry = f"{name}:{version}" if version else name
        # todo: consider creating another temp folder and copying these two files
        # todo: updating SDK to support uploading just specific files
        uploaded_files = [torch_state_dict_path]
        if init_kwargs_path:
            uploaded_files.append(init_kwargs_path)
        self._upload_model_files(name=model_registry, path=uploaded_files, metadata=metadata)

    @classmethod
    def download_model(
        cls,
        name: str,
        version: Optional[str] = None,
        temp_folder: Union[str, Path, None] = None,
        torch_load_kwargs: Optional[dict] = None,
    ) -> "torch.nn.Module":
        """Pull the model from the registry.

        Args:
            name: The name of the model.
            version: The version of the model. If None, the latest version is used.
            temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
            torch_load_kwargs: Additional arguments to pass to `torch.load()`.

        Raises:
            RuntimeError: If there is not exactly one ``.pth`` file or more than one init-arguments file.
            TypeError: If the re-created instance is not a ``torch.nn.Module``.
        """
        import torch

        if temp_folder is None:
            temp_folder = tempfile.mkdtemp()
        model_registry = f"{name}:{version}" if version else name
        # a single downloaded file may come back as a plain string; iterating it would yield characters
        files = _as_file_list(download_model_files(name=model_registry, download_dir=temp_folder))

        torch_files = [f for f in files if f.endswith(".pth")]
        if not torch_files:
            raise RuntimeError(f"No torch file found for model: {model_registry} with {files}")
        if len(torch_files) > 1:
            raise RuntimeError(f"Multiple torch files found for model: {model_registry} with {torch_files}")
        state_dict_path = Path(temp_folder) / torch_files[0]
        # ignore future warning about changed default
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            state_dict = torch.load(state_dict_path, **(torch_load_kwargs or {}))

        init_files = [fp for fp in files if fp.endswith("__init_kwargs.json")]
        if not init_files:
            init_kwargs = {}
        elif len(init_files) > 1:
            raise RuntimeError(f"Multiple init files found for model: {model_registry} with {init_files}")
        else:
            init_kwargs_path = Path(temp_folder) / init_files[0]
            with open(init_kwargs_path) as fp:
                init_kwargs = json.load(fp)

        # Re-create the model through its regular constructor with the recorded arguments
        instance = cls(**init_kwargs)
        if not isinstance(instance, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(instance)}")
        # Now load the state dict on the instance
        instance.load_state_dict(state_dict, strict=True)
        return instance


# --------------------------------------------------------------------------------
# /src/litmodels/io/__init__.py:
# --------------------------------------------------------------------------------
"""Root package for Input/output."""

from litmodels.io.cloud import download_model_files, upload_model_files  # noqa: F401
from litmodels.io.gateway import download_model, load_model, save_model, upload_model 5 | 6 | __all__ = ["download_model", "upload_model", "upload_model_files", "load_model", "save_model"] 7 | -------------------------------------------------------------------------------- /src/litmodels/io/cloud.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # http://www.apache.org/licenses/LICENSE-2.0 4 | # 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, Dict, List, Optional, Union 7 | 8 | from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL 9 | from lightning_sdk.models import _extend_model_name_with_teamspace, _parse_org_teamspace_model_version 10 | from lightning_sdk.models import delete_model as sdk_delete_model 11 | from lightning_sdk.models import download_model as sdk_download_model 12 | from lightning_sdk.models import upload_model as sdk_upload_model 13 | 14 | import litmodels 15 | 16 | if TYPE_CHECKING: 17 | from lightning_sdk.models import UploadedModelInfo 18 | 19 | 20 | _SHOWED_MODEL_LINKS = [] 21 | 22 | 23 | def _print_model_link(name: str, verbose: Union[bool, int]) -> None: 24 | """Print a link to the uploaded model. 25 | 26 | Args: 27 | name: Name of the model. 28 | verbose: Whether to print the link: 29 | 30 | - If set to 0, no link will be printed. 31 | - If set to 1, the link will be printed only once. 32 | - If set to 2, the link will be printed every time. 33 | """ 34 | name = _extend_model_name_with_teamspace(name) 35 | org_name, teamspace_name, model_name, _ = _parse_org_teamspace_model_version(name) 36 | 37 | url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}" 38 | msg = f"Model uploaded successfully. 
Link to the model: '{url}'" 39 | if int(verbose) > 1: 40 | print(msg) 41 | elif url not in _SHOWED_MODEL_LINKS: 42 | print(msg) 43 | _SHOWED_MODEL_LINKS.append(url) 44 | 45 | 46 | def upload_model_files( 47 | name: str, 48 | path: Union[str, Path, List[Union[str, Path]]], 49 | progress_bar: bool = True, 50 | cloud_account: Optional[str] = None, 51 | verbose: Union[bool, int] = 1, 52 | metadata: Optional[Dict[str, str]] = None, 53 | ) -> "UploadedModelInfo": 54 | """Upload a local checkpoint file to the model store. 55 | 56 | Args: 57 | name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' 58 | where entity is either your username or the name of an organization you are part of. 59 | path: Path to the model file to upload. 60 | progress_bar: Whether to show a progress bar for the upload. 61 | cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined 62 | automatically. 63 | verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed. 64 | metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used. 65 | 66 | """ 67 | if not metadata: 68 | metadata = {} 69 | metadata.update({"litModels": litmodels.__version__}) 70 | info = sdk_upload_model( 71 | name=name, 72 | path=path, 73 | progress_bar=progress_bar, 74 | cloud_account=cloud_account, 75 | metadata=metadata, 76 | ) 77 | if verbose: 78 | _print_model_link(name, verbose) 79 | return info 80 | 81 | 82 | def download_model_files( 83 | name: str, 84 | download_dir: Union[str, Path] = ".", 85 | progress_bar: bool = True, 86 | ) -> Union[str, List[str]]: 87 | """Download a checkpoint from the model store. 88 | 89 | Args: 90 | name: Name of the model to download. Must be in the format 'organization/teamspace/modelname' 91 | where entity is either your username or the name of an organization you are part of. 
92 | download_dir: A path to directory where the model should be downloaded. Defaults 93 | to the current working directory. 94 | progress_bar: Whether to show a progress bar for the download. 95 | 96 | Returns: 97 | The absolute path to the downloaded model file or folder. 98 | """ 99 | return sdk_download_model( 100 | name=name, 101 | download_dir=download_dir, 102 | progress_bar=progress_bar, 103 | ) 104 | 105 | 106 | def _list_available_teamspaces() -> Dict[str, dict]: 107 | """List available teamspaces for the authenticated user. 108 | 109 | Returns: 110 | Dict with teamspace names as keys and their details as values. 111 | """ 112 | from lightning_sdk.api import OrgApi, UserApi 113 | from lightning_sdk.utils import resolve as sdk_resolvers 114 | 115 | org_api = OrgApi() 116 | user = sdk_resolvers._get_authed_user() 117 | teamspaces = {} 118 | for ts in UserApi()._get_all_teamspace_memberships(""): 119 | if ts.owner_type == "organization": 120 | org = org_api._get_org_by_id(ts.owner_id) 121 | teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name} 122 | elif ts.owner_type == "user": # todo: check also the name 123 | teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user} 124 | else: 125 | raise RuntimeError(f"Unknown organization type {ts.organization_type}") 126 | return teamspaces 127 | 128 | 129 | def delete_model_version( 130 | name: str, 131 | version: Optional[str] = None, 132 | ) -> None: 133 | """Delete a model version from the model store. 134 | 135 | Args: 136 | name: Name of the model to delete. Must be in the format 'organization/teamspace/modelname' 137 | where entity is either your username or the name of an organization you are part of. 138 | version: Version of the model to delete. If not provided, all versions will be deleted. 
139 | """ 140 | sdk_delete_model(name=f"{name}:{version}") 141 | -------------------------------------------------------------------------------- /src/litmodels/io/gateway.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from pathlib import Path 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union 5 | 6 | from litmodels.io.cloud import download_model_files, upload_model_files 7 | from litmodels.io.utils import _KERAS_AVAILABLE, _PYTORCH_AVAILABLE, dump_pickle, load_pickle 8 | 9 | if _PYTORCH_AVAILABLE: 10 | import torch 11 | 12 | if _KERAS_AVAILABLE: 13 | from tensorflow import keras 14 | 15 | if TYPE_CHECKING: 16 | from lightning_sdk.models import UploadedModelInfo 17 | 18 | 19 | def upload_model( 20 | name: str, 21 | model: Union[str, Path], 22 | progress_bar: bool = True, 23 | cloud_account: Optional[str] = None, 24 | verbose: Union[bool, int] = 1, 25 | metadata: Optional[Dict[str, str]] = None, 26 | ) -> "UploadedModelInfo": 27 | """Upload a checkpoint to the model store. 28 | 29 | Args: 30 | name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' 31 | where entity is either your username or the name of an organization you are part of. 32 | model: The model to upload. Can be a path to a checkpoint file or a folder. 33 | progress_bar: Whether to show a progress bar for the upload. 34 | cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined 35 | automatically. 36 | verbose: Whether to print some additional information about the uploaded model. 37 | metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used. 38 | 39 | """ 40 | if not isinstance(model, (str, Path)): 41 | raise ValueError( 42 | "The `model` argument should be a path to a file or folder, not an python object." 
43 | " For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead." 44 | ) 45 | 46 | return upload_model_files( 47 | path=model, 48 | name=name, 49 | progress_bar=progress_bar, 50 | cloud_account=cloud_account, 51 | verbose=verbose, 52 | metadata=metadata, 53 | ) 54 | 55 | 56 | def save_model( 57 | name: str, 58 | model: Union["torch.nn.Module", Any], 59 | progress_bar: bool = True, 60 | cloud_account: Optional[str] = None, 61 | staging_dir: Optional[str] = None, 62 | verbose: Union[bool, int] = 1, 63 | metadata: Optional[Dict[str, str]] = None, 64 | ) -> "UploadedModelInfo": 65 | """Upload a checkpoint to the model store. 66 | 67 | Args: 68 | name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' 69 | where entity is either your username or the name of an organization you are part of. 70 | model: The model to upload. Can be a PyTorch model, or a Lightning model a. 71 | progress_bar: Whether to show a progress bar for the upload. 72 | cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined 73 | automatically. 74 | staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will 75 | be created and used. 76 | verbose: Whether to print some additional information about the uploaded model. 77 | metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used. 78 | 79 | """ 80 | if isinstance(model, (str, Path)): 81 | raise ValueError( 82 | "The `model` argument should be a PyTorch model or a Lightning model, not a path to a file." 83 | " With file or folder path use `upload_model` instead." 
84 | ) 85 | 86 | if not staging_dir: 87 | staging_dir = tempfile.mkdtemp() 88 | # if LightningModule and isinstance(model, LightningModule): 89 | # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt") 90 | # model.save_checkpoint(path) 91 | if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule): 92 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts") 93 | model.save(path) 94 | elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module): 95 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth") 96 | torch.save(model.state_dict(), path) 97 | elif _KERAS_AVAILABLE and isinstance(model, keras.models.Model): 98 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.keras") 99 | model.save(path) 100 | else: 101 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl") 102 | dump_pickle(model=model, path=path) 103 | 104 | if not metadata: 105 | metadata = {} 106 | metadata.update({"litModels.integration": "save_model"}) 107 | 108 | return upload_model( 109 | model=path, 110 | name=name, 111 | progress_bar=progress_bar, 112 | cloud_account=cloud_account, 113 | verbose=verbose, 114 | metadata=metadata, 115 | ) 116 | 117 | 118 | def download_model( 119 | name: str, 120 | download_dir: Union[str, Path] = ".", 121 | progress_bar: bool = True, 122 | ) -> Union[str, List[str]]: 123 | """Download a checkpoint from the model store. 124 | 125 | Args: 126 | name: Name of the model to download. Must be in the format 'organization/teamspace/modelname' 127 | where entity is either your username or the name of an organization you are part of. 128 | download_dir: A path to directory where the model should be downloaded. Defaults 129 | to the current working directory. 130 | progress_bar: Whether to show a progress bar for the download. 131 | 132 | Returns: 133 | The absolute path to the downloaded model file or folder. 
134 | """ 135 | return download_model_files( 136 | name=name, 137 | download_dir=download_dir, 138 | progress_bar=progress_bar, 139 | ) 140 | 141 | 142 | def load_model(name: str, download_dir: str = ".") -> Any: 143 | """Download a model from the model store and load it into memory. 144 | 145 | Args: 146 | name: Name of the model to download. Must be in the format 'organization/teamspace/modelname' 147 | where entity is either your username or the name of an organization you are part of. 148 | download_dir: A path to directory where the model should be downloaded. Defaults 149 | to the current working directory. 150 | 151 | Returns: 152 | The loaded model. 153 | """ 154 | download_paths = download_model(name=name, download_dir=download_dir) 155 | # filter out all Markdown, TXT and RST files 156 | download_paths = [p for p in download_paths if Path(p).suffix.lower() not in {".md", ".txt", ".rst"}] 157 | if len(download_paths) > 1: 158 | raise NotImplementedError("Downloaded model with multiple files is not supported yet.") 159 | model_path = Path(download_dir) / download_paths[0] 160 | if model_path.suffix.lower() == ".ts": 161 | return torch.jit.load(model_path) 162 | if model_path.suffix.lower() == ".keras": 163 | return keras.models.load_model(model_path) 164 | if model_path.suffix.lower() == ".pkl": 165 | return load_pickle(path=model_path) 166 | raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.") 167 | -------------------------------------------------------------------------------- /src/litmodels/io/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | from typing import Any, Union 4 | 5 | from lightning_utilities import module_available 6 | from lightning_utilities.core.imports import RequirementCache 7 | 8 | _JOBLIB_AVAILABLE = module_available("joblib") 9 | _PYTORCH_AVAILABLE = module_available("torch") 10 | _TENSORFLOW_AVAILABLE = 
module_available("tensorflow") 11 | _KERAS_AVAILABLE = RequirementCache("tensorflow >=2.0.0") 12 | 13 | if _JOBLIB_AVAILABLE: 14 | import joblib 15 | 16 | 17 | def dump_pickle(model: Any, path: Union[str, Path]) -> None: 18 | """Dump a model to a pickle file. 19 | 20 | Args: 21 | model: The model to be pickled. 22 | path: The path where the model will be saved. 23 | """ 24 | if _JOBLIB_AVAILABLE: 25 | joblib.dump(model, filename=path, compress=7) 26 | else: 27 | with open(path, "wb") as fp: 28 | pickle.dump(model, fp, protocol=pickle.HIGHEST_PROTOCOL) 29 | 30 | 31 | def load_pickle(path: Union[str, Path]) -> Any: 32 | """Load a model from a pickle file. 33 | 34 | Args: 35 | path: The path to the pickle file. 36 | 37 | Returns: 38 | The unpickled model. 39 | """ 40 | if _JOBLIB_AVAILABLE: 41 | return joblib.load(path) 42 | with open(path, "rb") as fp: 43 | return pickle.load(fp) 44 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Configure local testing.""" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest configuration for integration tests.""" 2 | 3 | import pytest 4 | 5 | from litmodels.integrations.checkpoints import get_model_manager 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def reset_model_manager(): 10 | get_model_manager.cache_clear() 11 | # Optionally, call it once to initialize immediately 12 | return get_model_manager() 13 | -------------------------------------------------------------------------------- /tests/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from litmodels.integrations.imports import ( 4 | _LIGHTNING_AVAILABLE, 5 | _LIGHTNING_GREATER_EQUAL_2_5_1, 6 | 
_PYTORCHLIGHTNING_AVAILABLE, 7 | _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, 8 | ) 9 | 10 | _SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available") 11 | _SKIP_IF_LIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif( 12 | not _LIGHTNING_GREATER_EQUAL_2_5_1, reason="Lightning without integration introduced in 2.5.1" 13 | ) 14 | _SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif( 15 | not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available" 16 | ) 17 | _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif( 18 | not _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, reason="PyTorch Lightning without integration introduced in 2.5.1" 19 | ) 20 | 21 | LIT_ORG = "lightning-ai" 22 | LIT_TEAMSPACE = "OSS | litModels" 23 | -------------------------------------------------------------------------------- /tests/integrations/test_checkpoints.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import re 3 | from unittest import mock 4 | 5 | import pytest 6 | 7 | import litmodels 8 | from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "importing", 13 | [ 14 | pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING), 15 | pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING), 16 | ], 17 | ) 18 | @pytest.mark.parametrize( 19 | "model_name", [None, "org-name/teamspace/model-name", "model-in-studio", "model-user-only-project"] 20 | ) 21 | @pytest.mark.parametrize("clear_all_local", [True, False]) 22 | @pytest.mark.parametrize("keep_all_uploaded", [True, False]) 23 | @mock.patch("litmodels.io.cloud.sdk_delete_model") 24 | @mock.patch("litmodels.io.cloud.sdk_upload_model") 25 | @mock.patch("litmodels.integrations.checkpoints.Auth") 26 | def test_lightning_checkpoint_callback( 27 | mock_auth, 28 | mock_upload_model, 29 | mock_delete_model, 30 | monkeypatch, 31 | importing, 
32 | model_name, 33 | clear_all_local, 34 | keep_all_uploaded, 35 | tmp_path, 36 | ): 37 | if importing == "lightning": 38 | from lightning.pytorch import Trainer 39 | from lightning.pytorch.callbacks import ModelCheckpoint 40 | from lightning.pytorch.demos.boring_classes import BoringModel 41 | 42 | from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint 43 | elif importing == "pytorch_lightning": 44 | from pytorch_lightning import Trainer 45 | from pytorch_lightning.callbacks import ModelCheckpoint 46 | from pytorch_lightning.demos.boring_classes import BoringModel 47 | 48 | from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint 49 | 50 | # Validate inheritance 51 | assert issubclass(LitModelCheckpoint, ModelCheckpoint) 52 | 53 | ckpt_args = {"clear_all_local": clear_all_local, "keep_all_uploaded": keep_all_uploaded} 54 | if model_name: 55 | ckpt_args.update({"model_registry": model_name}) 56 | 57 | all_model_registry = { 58 | "org-name/teamspace/model-name": {"org": "org-name", "teamspace": "teamspace", "model": "model-name"}, 59 | "model-in-studio": {"org": "my-org", "teamspace": "dream-team", "model": "model-in-studio"}, 60 | "model-user-only-project": {"org": "my-org", "teamspace": "default-ts", "model": "model-user-only-project"}, 61 | } 62 | expected_boring_model = "BoringModel_20250102-1213" 63 | expected_model_registry = all_model_registry.get( 64 | model_name, 65 | {"org": "org-name", "teamspace": "teamspace", "model": expected_boring_model}, 66 | ) 67 | expected_org = expected_model_registry["org"] 68 | expected_teamspace = expected_model_registry["teamspace"] 69 | expected_model = expected_model_registry["model"] 70 | mock_upload_model.return_value.name = f"{expected_org}/{expected_teamspace}/{expected_model}" 71 | monkeypatch.setattr( 72 | "litmodels.integrations.checkpoints.LitModelCheckpointMixin.default_model_name", 73 | 
mock.MagicMock(return_value=expected_boring_model), 74 | ) 75 | if model_name is None or model_name == "model-in-studio": 76 | mock_teamspace = mock.Mock(owner=mock.Mock()) 77 | mock_teamspace.owner.name = expected_org 78 | mock_teamspace.name = expected_teamspace 79 | 80 | monkeypatch.setattr( 81 | "litmodels.integrations.checkpoints._resolve_teamspace", mock.MagicMock(return_value=mock_teamspace) 82 | ) 83 | elif model_name == "model-user-only-project": 84 | monkeypatch.setattr("litmodels.integrations.checkpoints._resolve_teamspace", mock.MagicMock(return_value=None)) 85 | monkeypatch.setattr( 86 | "litmodels.integrations.checkpoints._list_available_teamspaces", 87 | mock.MagicMock(return_value={f"{expected_org}/{expected_teamspace}": {}}), 88 | ) 89 | 90 | # mocking the trainer delete checkpoint removal 91 | mock_remove_ckpt = mock.Mock() 92 | # setting the Trainer and custom checkpointing 93 | trainer = Trainer( 94 | max_epochs=2, 95 | callbacks=LitModelCheckpoint(**ckpt_args), 96 | ) 97 | trainer.strategy.remove_checkpoint = mock_remove_ckpt 98 | trainer.fit(BoringModel()) 99 | 100 | assert mock_auth.call_count == 1 101 | assert mock_upload_model.call_args_list == [ 102 | mock.call( 103 | name=f"{expected_org}/{expected_teamspace}/{expected_model}:{v}", 104 | path=mock.ANY, 105 | progress_bar=True, 106 | cloud_account=None, 107 | metadata={"litModels.integration": LitModelCheckpoint.__name__, "litModels": litmodels.__version__}, 108 | ) 109 | for v in ("epoch=0-step=64", "epoch=1-step=128") 110 | ] 111 | expected_local_removals = 2 if clear_all_local else 1 112 | assert mock_remove_ckpt.call_count == expected_local_removals 113 | 114 | expected_cloud_removals = 0 if keep_all_uploaded else 1 115 | assert mock_delete_model.call_count == expected_cloud_removals 116 | if expected_cloud_removals: 117 | mock_delete_model.assert_called_once_with( 118 | name=f"{expected_org}/{expected_teamspace}/{expected_model}:epoch=0-step=64" 119 | ) 120 | 121 | # Verify paths 
match the expected pattern 122 | for call_args in mock_upload_model.call_args_list: 123 | path = call_args[1]["path"] 124 | assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path) 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "importing", 129 | [ 130 | pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING), 131 | pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING), 132 | ], 133 | ) 134 | @mock.patch("litmodels.integrations.checkpoints.Auth") 135 | def test_lightning_checkpointing_pickleable(mock_auth, importing): 136 | if importing == "lightning": 137 | from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint 138 | elif importing == "pytorch_lightning": 139 | from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint 140 | 141 | ckpt = LitModelCheckpoint(model_registry="org-name/teamspace/model-name") 142 | assert mock_auth.call_count == 1 143 | pickle.dumps(ckpt) 144 | -------------------------------------------------------------------------------- /tests/integrations/test_duplicate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import mock 3 | 4 | from litmodels.integrations.duplicate import duplicate_hf_model 5 | 6 | 7 | @mock.patch("litmodels.integrations.duplicate.snapshot_download") 8 | @mock.patch("litmodels.integrations.duplicate.upload_model_files") 9 | def test_duplicate_hf_model(mock_upload_model, mock_snapshot_download, tmp_path): 10 | """Verify that the HF model can be duplicated to the teamspace""" 11 | 12 | hf_model = "google/t5-efficient-tiny" 13 | # model name with random hash 14 | model_name = f"litmodels_hf_model+{os.urandom(8).hex()}" 15 | duplicate_hf_model(hf_model=hf_model, lit_model=model_name, local_workdir=str(tmp_path)) 16 | 17 | mock_snapshot_download.assert_called_with( 18 | repo_id=hf_model, 19 | 
revision="main", 20 | repo_type="model", 21 | local_dir=tmp_path / hf_model.replace("/", "_"), 22 | local_dir_use_symlinks=True, 23 | ignore_patterns=[".cache*"], 24 | max_workers=os.cpu_count(), 25 | ) 26 | mock_upload_model.assert_called_with( 27 | name=f"{model_name}", 28 | path=tmp_path / hf_model.replace("/", "_"), 29 | metadata={"hf_model": hf_model, "litModels.integration": "duplicate_hf_model"}, 30 | verbose=1, 31 | ) 32 | -------------------------------------------------------------------------------- /tests/integrations/test_mixins.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin 8 | 9 | 10 | class DummyModel(PickleRegistryMixin): 11 | def __init__(self, value): 12 | self.value = value 13 | 14 | def __eq__(self, other): 15 | return isinstance(other, DummyModel) and self.value == other.value 16 | 17 | 18 | @mock.patch("litmodels.integrations.mixins.upload_model_files") 19 | @mock.patch("litmodels.integrations.mixins.download_model_files") 20 | def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path): 21 | # Create an instance of DummyModel and call push_to_registry. 22 | dummy = DummyModel(42) 23 | dummy.upload_model(version="v1", temp_folder=str(tmp_path)) 24 | # The expected registry name is "dummy_model:v1" and the file should be placed in the temp folder. 25 | expected_path = tmp_path / "DummyModel.pkl" 26 | mock_upload_model.assert_called_once_with( 27 | name="DummyModel:v1", path=expected_path, metadata={"litModels.integration": "PickleRegistryMixin"} 28 | ) 29 | 30 | # Set the mock to return the full path to the pickle file. 31 | mock_download_model.return_value = ["DummyModel.pkl"] 32 | # Call pull_from_registry and load the DummyModel instance. 
33 | loaded_dummy = DummyModel.download_model(name="dummy_model", version="v1", temp_folder=str(tmp_path)) 34 | # Verify that the unpickled instance has the expected value. 35 | assert loaded_dummy.value == 42 36 | 37 | 38 | class DummyTorchModelFirst(PyTorchRegistryMixin, nn.Module): 39 | def __init__(self, input_size: int, output_size: int = 10): 40 | # PyTorchRegistryMixin.__init__ will capture these arguments 41 | super().__init__() 42 | self.fc = nn.Linear(input_size, output_size) 43 | 44 | def forward(self, x): 45 | x = x.view(x.size(0), -1) 46 | return self.fc(x) 47 | 48 | 49 | class DummyTorchModelSecond(nn.Module, PyTorchRegistryMixin): 50 | def __init__(self, input_size: int, output_size: int = 10): 51 | PyTorchRegistryMixin.__init__(input_size, output_size) 52 | super().__init__() 53 | self.fc = nn.Linear(input_size, output_size) 54 | 55 | def forward(self, x): 56 | x = x.view(x.size(0), -1) 57 | return self.fc(x) 58 | 59 | 60 | @pytest.mark.parametrize("torch_class", [DummyTorchModelFirst, DummyTorchModelSecond]) 61 | @mock.patch("litmodels.integrations.mixins.upload_model_files") 62 | @mock.patch("litmodels.integrations.mixins.download_model_files") 63 | def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, torch_class, tmp_path): 64 | # Create an instance, push the model and record its forward output. 
65 | dummy = torch_class(784) 66 | dummy.eval() 67 | input_tensor = torch.randn(1, 784) 68 | output_before = dummy(input_tensor) 69 | 70 | torch_file = f"{dummy.__class__.__name__}.pth" 71 | torch.save(dummy.state_dict(), tmp_path / torch_file) 72 | json_file = f"{dummy.__class__.__name__}__init_kwargs.json" 73 | json_path = tmp_path / json_file 74 | with open(json_path, "w") as fp: 75 | fp.write('{"input_size": 784, "output_size": 10}') 76 | 77 | dummy.upload_model(temp_folder=str(tmp_path)) 78 | mock_upload_model.assert_called_once_with( 79 | name=torch_class.__name__, 80 | path=[tmp_path / f"{torch_class.__name__}.pth", json_path], 81 | metadata={"litModels.integration": "PyTorchRegistryMixin"}, 82 | ) 83 | 84 | # Prepare mocking for pull_from_registry. 85 | mock_download_model.return_value = [torch_file, json_file] 86 | loaded_dummy = torch_class.download_model(name=torch_class.__name__, temp_folder=str(tmp_path)) 87 | loaded_dummy.eval() 88 | output_after = loaded_dummy(input_tensor) 89 | 90 | assert isinstance(loaded_dummy, torch_class) 91 | # Compare the outputs as a verification. 92 | assert torch.allclose(output_before, output_after), "Loaded model output differs from original." 
93 | -------------------------------------------------------------------------------- /tests/integrations/test_real_cloud.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | from contextlib import redirect_stdout 4 | from io import StringIO 5 | from typing import Optional 6 | 7 | import pytest 8 | import torch 9 | from lightning_sdk import Teamspace 10 | from lightning_sdk.lightning_cloud.rest_client import GridRestClient 11 | from lightning_sdk.utils.resolve import _resolve_teamspace 12 | 13 | from litmodels import download_model, load_model, save_model, upload_model 14 | from litmodels.integrations.duplicate import duplicate_hf_model 15 | from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin 16 | from litmodels.io.cloud import _list_available_teamspaces 17 | from litmodels.io.utils import _KERAS_AVAILABLE 18 | from tests.integrations import ( 19 | _SKIP_IF_LIGHTNING_BELLOW_2_5_1, 20 | _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1, 21 | LIT_ORG, 22 | LIT_TEAMSPACE, 23 | ) 24 | 25 | 26 | def _prepare_variables(test_name: str) -> tuple[Teamspace, str, str]: 27 | model_name = f"ci-test_integrations_{test_name}+{os.urandom(8).hex()}" 28 | teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) 29 | org_team = f"{teamspace.owner.name}/{teamspace.name}" 30 | return teamspace, org_team, model_name 31 | 32 | 33 | def _cleanup_model(teamspace: Teamspace, model_name: str, expected_num_versions: Optional[int] = None) -> None: 34 | """Cleanup model from the teamspace.""" 35 | client = GridRestClient() 36 | # cleaning created models as each test run shall have unique model name 37 | model = client.models_store_get_model_by_name( 38 | project_owner_name=teamspace.owner.name, 39 | project_name=teamspace.name, 40 | model_name=model_name, 41 | ) 42 | if expected_num_versions is not None: 43 | versions = 
client.models_store_list_model_versions(project_id=model.project_id, model_id=model.id) 44 | assert expected_num_versions == len(versions.versions) 45 | client.models_store_delete_model(project_id=model.project_id, model_id=model.id) 46 | 47 | 48 | @pytest.mark.cloud 49 | @pytest.mark.parametrize( 50 | "in_studio", 51 | [False, pytest.param(True, marks=pytest.mark.skipif(platform.system() != "Linux", reason="Studio is just Linux"))], 52 | ) 53 | def test_upload_download_model(in_studio, monkeypatch, tmp_path): 54 | """Verify that the model is uploaded to the teamspace""" 55 | if in_studio: 56 | # mock env variables as it would run in studio 57 | monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) 58 | monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) 59 | 60 | # create a dummy file 61 | file_path = tmp_path / "dummy.txt" 62 | with open(file_path, "w") as f: 63 | f.write("dummy") 64 | 65 | # model name with random hash 66 | teamspace, org_team, model_name = _prepare_variables("upload_download") 67 | model_registry = f"{org_team}/{model_name}" if not in_studio else model_name 68 | 69 | out = StringIO() 70 | with redirect_stdout(out): 71 | upload_model(name=model_registry, model=file_path) 72 | 73 | # validate the output 74 | assert ( 75 | f"Model uploaded successfully. 
Link to the model: 'https://lightning.ai/{org_team}/models/{model_name}'" 76 | ) in out.getvalue() 77 | 78 | os.remove(file_path) 79 | assert not os.path.isfile(file_path) 80 | 81 | model_files = download_model(name=model_registry, download_dir=tmp_path) 82 | assert model_files == ["dummy.txt"] 83 | for file in model_files: 84 | assert os.path.isfile(os.path.join(tmp_path, file)) 85 | 86 | # CLEANING 87 | _cleanup_model(teamspace, model_name, expected_num_versions=1) 88 | 89 | 90 | @pytest.mark.parametrize( 91 | "importing", 92 | [ 93 | pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1), 94 | pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1), 95 | ], 96 | ) 97 | @pytest.mark.parametrize( 98 | "in_studio", 99 | [ 100 | False, 101 | pytest.param(True, marks=pytest.mark.skipif(platform.system() == "Windows", reason="studio is not Windows")), 102 | ], 103 | ) 104 | @pytest.mark.cloud 105 | def test_lightning_default_checkpointing(importing, in_studio, monkeypatch, tmp_path): 106 | if in_studio: 107 | # mock env variables as it would run in studio 108 | monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) 109 | monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) 110 | 111 | if importing == "lightning": 112 | from lightning import Trainer 113 | from lightning.pytorch.demos.boring_classes import BoringModel 114 | elif importing == "pytorch_lightning": 115 | from pytorch_lightning import Trainer 116 | from pytorch_lightning.demos.boring_classes import BoringModel 117 | 118 | # model name with random hash 119 | teamspace, org_team, model_name = _prepare_variables("default_checkpoint") 120 | model_registry = f"{org_team}/{model_name}" if not in_studio else model_name 121 | 122 | trainer = Trainer( 123 | max_epochs=2, 124 | default_root_dir=tmp_path, 125 | model_registry=model_registry, 126 | ) 127 | trainer.fit(BoringModel()) 128 | 129 | # CLEANING 130 | _cleanup_model(teamspace, model_name, expected_num_versions=2) 131 | 132 | 133 | 
# Placeholder inside a parametrized ``registry`` ckpt-path value; at test time it is replaced
# with the actual registry name (``model_registry``). It MUST be non-empty: ``"" in s`` is
# always True and ``s.replace("", x)`` inserts ``x`` between every character of ``s``.
_MODEL_PLACEHOLDER = "<model>"


@pytest.mark.parametrize("trainer_method", ["fit", "validate", "test", "predict"])
@pytest.mark.parametrize(
    "registry",
    [
        "registry",
        "registry:version:v1",
        f"registry:{_MODEL_PLACEHOLDER}",
        f"registry:{_MODEL_PLACEHOLDER}:version:v1",
    ],
)
@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
    ],
)
@pytest.mark.parametrize(
    "in_studio",
    [
        False,
        pytest.param(True, marks=pytest.mark.skipif(platform.system() == "Windows", reason="studio is not Windows")),
    ],
)
@pytest.mark.cloud
def test_lightning_plain_resume(trainer_method, registry, importing, in_studio, tmp_path, monkeypatch):
    """Upload a locally trained checkpoint, then run a Trainer entry point with a ``registry`` ckpt path.

    ``registry`` values without the model placeholder rely on ``Trainer(model_registry=...)`` to know
    which model to resolve. Values containing the placeholder spell the model name out in ``ckpt_path``
    instead, so the Trainer is built without ``model_registry``. The test expects one extra model
    version only for ``fit`` with ``model_registry`` set.
    """
    if importing == "lightning":
        from lightning import Trainer
        from lightning.pytorch.demos.boring_classes import BoringModel
    elif importing == "pytorch_lightning":
        from pytorch_lightning import Trainer
        from pytorch_lightning.demos.boring_classes import BoringModel

    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)

    # Produce a real checkpoint locally; its best checkpoint is the artifact uploaded below.
    trainer = Trainer(max_epochs=1, limit_train_batches=50, limit_val_batches=20, default_root_dir=tmp_path)
    trainer.fit(BoringModel())
    checkpoint_path = trainer.checkpoint_callback.best_model_path

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables(f"resume_{trainer_method}")
    model_registry = f"{org_team}/{model_name}" if not in_studio else model_name
    upload_model(model=checkpoint_path, name=model_registry)
    expected_num_versions = 1

    # Without the placeholder in the ckpt path, the Trainer itself must carry the registry name.
    trainer_kwargs = {"model_registry": model_registry} if _MODEL_PLACEHOLDER not in registry else {}
    trainer = Trainer(
        max_epochs=2,
        default_root_dir=tmp_path,
        limit_train_batches=10,
        limit_val_batches=10,
        limit_test_batches=10,
        limit_predict_batches=10,
        **trainer_kwargs,
    )
    ckpt_path = registry.replace(_MODEL_PLACEHOLDER, model_registry)
    if trainer_method == "fit":
        trainer.fit(BoringModel(), ckpt_path=ckpt_path)
        # with ``model_registry`` set, continued training is expected to add one more version
        if trainer_kwargs:
            expected_num_versions += 1
    elif trainer_method == "validate":
        trainer.validate(BoringModel(), ckpt_path=ckpt_path)
    elif trainer_method == "test":
        trainer.test(BoringModel(), ckpt_path=ckpt_path)
    elif trainer_method == "predict":
        trainer.predict(BoringModel(), ckpt_path=ckpt_path)
    else:
        raise ValueError(f"Unknown trainer method: {trainer_method}")

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=expected_num_versions)


@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
    ],
)
@pytest.mark.cloud
def test_lightning_checkpoint_ddp(importing, tmp_path):
    """Fit with ``ddp_spawn`` on 4 CPU processes while checkpointing to the model registry.

    The cleanup step expects two model versions after the two-epoch run.
    """
    if importing == "lightning":
        from lightning import Trainer
        from lightning.pytorch.demos.boring_classes import BoringModel
    elif importing == "pytorch_lightning":
        from pytorch_lightning import Trainer
        from pytorch_lightning.demos.boring_classes import BoringModel

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("checkpoint_resume")
    trainer_args = {
        "default_root_dir": tmp_path,
        "accelerator": "cpu",
        "strategy": "ddp_spawn",
        "devices": 4,
        "model_registry": f"{org_team}/{model_name}",
    }

    trainer = Trainer(max_epochs=2, **trainer_args)
    trainer.fit(BoringModel())

    # FIXME: seems like barrier is not respected in the test, but in real life it correctly waits for all GPUs
    # trainer = Trainer(max_epochs=5, **trainer_args)
    # trainer.fit(BoringModel(), ckpt_path="registry")

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=2)


class DummyModel(PickleRegistryMixin):
    """Minimal object carrying a single ``value``, used to exercise ``PickleRegistryMixin``."""

    def __init__(self, value):
        self.value = value


@pytest.mark.cloud
def test_pickle_mixin_push_and_pull():
    """Upload a ``DummyModel`` via the mixin, download it back, and check its state survived."""
    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("pickle_mixin")
    model_registry = f"{org_team}/{model_name}"

    # Create an instance of DummyModel and upload it to the registry.
    dummy = DummyModel(42)
    dummy.upload_model(model_registry)

    # Download it back as a fresh DummyModel instance.
    loaded_dummy = DummyModel.download_model(model_registry)
    # Verify that the restored instance has the expected value.
    assert isinstance(loaded_dummy, DummyModel)
    assert loaded_dummy.value == 42

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=1)


# This is a dummy model for PyTorch that uses the PyTorchRegistryMixin.
# This mixin has to be first in the inheritance order.
# Otherwise, `PyTorchRegistryMixin.__init__` needs to be called explicitly.
class DummyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
    """Single linear layer over a flattened input, used to exercise ``PyTorchRegistryMixin``."""

    def __init__(self, input_size: int, output_size: int = 10):
        # PyTorchRegistryMixin.__init__ will capture these arguments
        super().__init__()
        self.fc = torch.nn.Linear(input_size, output_size)

    def forward(self, x):
        # keep dim 0 and flatten all remaining dims before the linear layer
        x = x.view(x.size(0), -1)
        return self.fc(x)


@pytest.mark.cloud
def test_pytorch_mixin_push_and_pull():
    """Upload a ``DummyTorchModel`` and check the downloaded copy gives the same forward output."""
    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("torch_mixin")
    model_registry = f"{org_team}/{model_name}"

    # Create an instance, push the model and record its forward output.
    dummy = DummyTorchModel(784)
    dummy.eval()
    input_tensor = torch.randn(1, 784)
    output_before = dummy(input_tensor)

    dummy.upload_model(model_registry)

    loaded_dummy = DummyTorchModel.download_model(model_registry)
    loaded_dummy.eval()
    output_after = loaded_dummy(input_tensor)

    assert isinstance(loaded_dummy, DummyTorchModel)
    # Compare the outputs as a verification.
    assert torch.allclose(output_before, output_after), "Loaded model output differs from original."

    # CLEANING
    _cleanup_model(teamspace, model_name, expected_num_versions=1)


@pytest.mark.cloud
def test_duplicate_real_hf_model(tmp_path):
    """Verify that the HF model can be duplicated to the teamspace."""

    # model name with random hash
    model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
    teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
    org_team = f"{teamspace.owner.name}/{teamspace.name}"

    duplicate_hf_model(hf_model="google/t5-efficient-tiny", lit_model=f"{org_team}/{model_name}")

    # CLEANING: look the duplicated model up by name and delete it directly through the REST client
    client = GridRestClient()
    model = client.models_store_get_model_by_name(
        project_owner_name=teamspace.owner.name,
        project_name=teamspace.name,
        model_name=model_name,
    )
    client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)


@pytest.mark.cloud
def test_list_available_teamspaces():
    """The authenticated user sees at least one teamspace, including the test teamspace."""
    teams = _list_available_teamspaces()
    assert len(teams) > 0
    # using sanitized teamspace name
    assert f"{LIT_ORG}/oss-litmodels" in teams


@pytest.mark.cloud
@pytest.mark.skipif(
    not _KERAS_AVAILABLE,
    # the condition is package availability, so say that (the same wording as in test_io_cloud.py)
    reason="TensorFlow/Keras is not available",
)
def test_save_load_tensorflow_keras(tmp_path):
    """Save a compiled Keras model with ``save_model`` and load it back with ``load_model``."""
    from tensorflow import keras

    # Define the model
    model = keras.Sequential([
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ])

    # Compile the model
    model.compile(optimizer="adam", loss="categorical_crossentropy")

    # model name with random hash
    teamspace, org_team, model_name = _prepare_variables("tf-keras")
    save_model(f"{org_team}/{model_name}", model=model)

    # Load the model
    model_ = load_model(f"{org_team}/{model_name}", download_dir=str(tmp_path))

    # validate the model
    assert isinstance(model_, type(model))
    _cleanup_model(teamspace, model_name, expected_num_versions=1)


# --------------------------------------------------------------------------------
# /tests/test_io_cloud.py:
# --------------------------------------------------------------------------------
import os
from contextlib import nullcontext
from unittest import mock

import joblib
import pytest
import torch
import torch.jit as torch_jit
from sklearn import svm
from torch.nn import Module

import litmodels
from litmodels import download_model, load_model, save_model
from litmodels.io import upload_model_files
from litmodels.io.utils import _KERAS_AVAILABLE
from tests.integrations import LIT_ORG, LIT_TEAMSPACE


@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
@pytest.mark.parametrize("in_studio", [True, False])
def test_upload_wrong_model_name(name, in_studio, monkeypatch):
    """Names not shaped like ``organization/teamspace/model`` raise ``ValueError`` on upload.

    The one accepted exception is a bare model name inside a (mocked) Studio.
    """
    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
        monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.teamspace.TeamspaceApi", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.models._get_teamspace", mock.MagicMock)

    in_studio_only_name = in_studio and name == "model-name"
    with (
        pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
        if not in_studio_only_name
        else nullcontext()
    ):
        upload_model_files(path="path/to/checkpoint", name=name)


@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
@pytest.mark.parametrize("in_studio", [True, False])
def test_download_wrong_model_name(name, in_studio, monkeypatch):
    """Names not shaped like ``organization/teamspace/model`` raise ``ValueError`` on download.

    The one accepted exception is a bare model name inside a (mocked) Studio.
    """
    if in_studio:
        # mock env variables as it would run in studio
        monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
        monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
        monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
        monkeypatch.setattr("lightning_sdk.models.TeamspaceApi", mock.MagicMock)
    in_studio_only_name = in_studio and name == "model-name"
    with (
        pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
        if not in_studio_only_name
        else nullcontext()
    ):
        download_model(name=name)


@pytest.mark.parametrize(
    ("model", "model_path", "verbose"),
    [
        # ("path/to/checkpoint", "path/to/checkpoint", False),
        # (BoringModel(), "%s/BoringModel.ckpt"),
        (torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True),
        (Module(), f"%s{os.path.sep}Module.pth", True),
        (svm.SVC(), f"%s{os.path.sep}SVC.pkl", 1),
    ],
)
@mock.patch("litmodels.io.cloud.sdk_upload_model")
def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
    """``save_model`` serializes into ``staging_dir`` and forwards the file path to the SDK upload."""
    # the mocked SDK result needs a registry name, as a real upload would report
    mock_upload_model.return_value.name = "org-name/teamspace/model-name"

    # save_model is a thin wrapper around the SDK upload, so only the forwarded arguments are checked
    save_model(
        model=model,
        name="org-name/teamspace/model-name",
        cloud_account="cluster_id",
        staging_dir=str(tmp_path),
        verbose=verbose,
    )
    expected_path = model_path % str(tmp_path) if "%" in model_path else model_path
    mock_upload_model.assert_called_once_with(
        path=expected_path,
        name="org-name/teamspace/model-name",
        cloud_account="cluster_id",
        progress_bar=True,
        metadata={"litModels": litmodels.__version__, "litModels.integration": "save_model"},
    )
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_download_model(mock_download_model):
    """``download_model`` hands name and target directory to the SDK, with the progress bar on."""
    registry_name = "org-name/teamspace/model-name"
    target_dir = "where/to/download"

    download_model(name=registry_name, download_dir=target_dir)

    mock_download_model.assert_called_once_with(name=registry_name, download_dir=target_dir, progress_bar=True)


@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_pickle(mock_download_model, tmp_path):
    """``load_model`` restores a joblib-dumped ``.pkl`` artifact after the (mocked) SDK download."""
    # Put a real artifact in the download dir; the mocked SDK then reports its file name.
    artifact = tmp_path / "dummy_model.pkl"
    joblib.dump(svm.SVC(), artifact)
    mock_download_model.return_value = [artifact.name]
    download_dir = str(tmp_path)

    loaded = load_model(name="org-name/teamspace/model-name", download_dir=download_dir)

    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir=download_dir, progress_bar=True
    )
    assert isinstance(loaded, svm.SVC)


@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_torch_jit(mock_download_model, tmp_path):
    """``load_model`` restores a saved TorchScript ``.ts`` artifact as a ``ScriptModule``."""
    # Put a real artifact in the download dir; the mocked SDK then reports its file name.
    artifact = tmp_path / "dummy_model.ts"
    torch_jit.script(Module()).save(artifact)
    mock_download_model.return_value = [artifact.name]
    download_dir = str(tmp_path)

    loaded = load_model(name="org-name/teamspace/model-name", download_dir=download_dir)

    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir=download_dir, progress_bar=True
    )
    assert isinstance(loaded, torch.jit.ScriptModule)


@pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available")
@mock.patch("litmodels.io.cloud.sdk_download_model")
def test_load_model_tf_keras(mock_download_model, tmp_path):
    """``load_model`` restores a saved ``.keras`` artifact as a Keras ``Model``."""
    from tensorflow import keras

    # Build, compile and save a tiny two-layer model into the download dir.
    original = keras.Sequential([
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ])
    original.compile(optimizer="adam", loss="categorical_crossentropy")
    artifact = tmp_path / "dummy_model.keras"
    original.save(artifact)
    # The mocked SDK reports the artifact's file name as the downloaded file.
    mock_download_model.return_value = [artifact.name]
    download_dir = str(tmp_path)

    loaded = load_model(name="org-name/teamspace/model-name", download_dir=download_dir)

    mock_download_model.assert_called_once_with(
        name="org-name/teamspace/model-name", download_dir=download_dir, progress_bar=True
    )
    assert isinstance(loaded, keras.models.Model)


# --------------------------------------------------------------------------------