├── tests ├── __init__.py ├── requirements.txt ├── integration_app │ └── app.py └── test_lightning_integration.py ├── .lightning ├── demo_weights ├── images ├── 0.jpeg ├── 1.jpeg ├── 2.jpeg ├── 3.jpeg ├── 4.jpeg ├── 5.jpeg ├── 6.jpeg ├── 7.jpeg ├── 8.jpeg ├── 9.jpeg ├── 10.jpeg ├── 11.jpeg ├── 12.jpeg ├── 13.jpeg ├── 14.jpeg ├── 15.jpeg ├── 16.jpeg ├── 17.jpeg ├── 18.jpeg └── 19.jpeg ├── .lightningignore ├── quick_start ├── __init__.py ├── __about__.py ├── download.py ├── setup_tools.py └── components.py ├── requirements.txt ├── .github ├── ISSUE_TEMPLATE │ ├── documentation.md │ ├── feature_request.md │ └── bug_report.md ├── workflows │ ├── ci_install-pkg.yml │ └── ci_testing.yml ├── PULL_REQUEST_TEMPLATE.md └── stale.yml ├── MANIFEST.in ├── README.md ├── .pre-commit-config.yaml ├── app.py ├── app_hpo.py ├── setup.py ├── .gitignore └── train_script.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | -------------------------------------------------------------------------------- /.lightning: -------------------------------------------------------------------------------- 1 | cluster_id: litng-ai-03 2 | name: quick-start 3 | -------------------------------------------------------------------------------- /demo_weights: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/demo_weights -------------------------------------------------------------------------------- /images/0.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/0.jpeg 
-------------------------------------------------------------------------------- /images/1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/1.jpeg -------------------------------------------------------------------------------- /images/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/2.jpeg -------------------------------------------------------------------------------- /images/3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/3.jpeg -------------------------------------------------------------------------------- /images/4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/4.jpeg -------------------------------------------------------------------------------- /images/5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/5.jpeg -------------------------------------------------------------------------------- /images/6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/6.jpeg -------------------------------------------------------------------------------- /images/7.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/7.jpeg -------------------------------------------------------------------------------- 
/images/8.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/8.jpeg -------------------------------------------------------------------------------- /images/9.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/9.jpeg -------------------------------------------------------------------------------- /images/10.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/10.jpeg -------------------------------------------------------------------------------- /images/11.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/11.jpeg -------------------------------------------------------------------------------- /images/12.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/12.jpeg -------------------------------------------------------------------------------- /images/13.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/13.jpeg -------------------------------------------------------------------------------- /images/14.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/14.jpeg -------------------------------------------------------------------------------- /images/15.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/15.jpeg -------------------------------------------------------------------------------- /images/16.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/16.jpeg -------------------------------------------------------------------------------- /images/17.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/17.jpeg -------------------------------------------------------------------------------- /images/18.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/18.jpeg -------------------------------------------------------------------------------- /images/19.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-quick-start/HEAD/images/19.jpeg -------------------------------------------------------------------------------- /.lightningignore: -------------------------------------------------------------------------------- 1 | lightning_logs 2 | data 3 | data/* 4 | .shared 5 | *.ckpt 6 | *.shared* 7 | *data* 8 | *.git* 9 | *MNIST* 10 | *artifacts* 11 | *egg-info/* 12 | *model_weight.pt* 13 | *lightning-quick-start-quick_start_train* 14 | flagged 15 | *wandb* 16 | .git 17 | venv 18 | .venv 19 | -------------------------------------------------------------------------------- /quick_start/__init__.py: -------------------------------------------------------------------------------- 1 | from quick_start.__about__ import * # noqa: F401, F403 2 | from 
def exported_lightning_components() -> list:
    """Return the component classes this package exports.

    Returns:
        A list holding the ``PyTorchLightningScript`` and ``ImageServeGradio``
        classes themselves (not instances), in that order.
    """
    # NOTE(review): the name suggests this is a discovery hook read by Lightning
    # tooling -- confirm against the framework before renaming or reordering.
    return [PyTorchLightningScript, ImageServeGradio]
class RootFlow(LightningFlow):
    """Minimal root flow used by the integration test.

    Its only job is to prove that the ``quick_start`` components can be imported
    inside a running Lightning app: it prints both classes and then terminates
    the process successfully.
    """

    def run(self):
        """Print the imported components, then exit with status code 0.

        Raises:
            SystemExit: always, with code ``0``.
        """
        print(PyTorchLightningScript)
        print(ImageServeGradio)
        # ``exit`` is a helper injected by the ``site`` module for interactive
        # sessions and may be missing (``python -S``, frozen builds). Raising
        # ``SystemExit`` yields the same exit code without that dependency.
        raise SystemExit(0)
-------------------------------------------------------------------------------- 1 | __version__ = "0.0.13" 2 | __author__ = "PyTorchLightning et al." 3 | __author_email__ = "name@grid.ai" 4 | __license__ = "TBD" 5 | __copyright__ = f"Copyright (c) 2021-2022, {__author__}." 6 | __homepage__ = "https://github.com/PyTorchLightning/lightning-quick-start" 7 | __docs__ = "Lightning Quick Start" 8 | __long_doc__ = """ 9 | What is it? 10 | ----------- 11 | 12 | """ 13 | 14 | __all__ = [ 15 | "__author__", 16 | "__author_email__", 17 | "__copyright__", 18 | "__docs__", 19 | "__homepage__", 20 | "__license__", 21 | "__version__", 22 | ] 23 | -------------------------------------------------------------------------------- /.github/workflows/ci_install-pkg.yml: -------------------------------------------------------------------------------- 1 | name: Install pkg 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push or pull request, but only for the master branch 5 | push: 6 | branches: 7 | - 'main' 8 | pull_request: 9 | 10 | jobs: 11 | check-package: 12 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main 13 | with: 14 | actions-ref: main 15 | artifact-name: dist-packages-${{ github.sha }} 16 | import-name: "quick_start" 17 | testing-matrix: | 18 | { 19 | "os": ["ubuntu-22.04", "macos-13", "windows-2022"], 20 | "python-version": ["3.10"] 21 | } 22 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Before submitting 2 | 3 | - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) 4 | - [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section? 5 | - [ ] Did you make sure to update the docs? 
6 | - [ ] Did you write any new necessary tests? 7 | 8 | ## What does this PR do? 9 | 10 | Fixes # (issue). 11 | 12 | ## PR review 13 | 14 | Anyone in the community is free to review the PR once the tests have passed. 15 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 16 | 17 | ## Did you have fun? 18 | 19 | Make sure you had fun coding 🙃 20 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 60 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 9 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - pinned 10 | - security 11 | # Label to use when marking an issue as stale 12 | staleLabel: wontfix 13 | # Comment to post when marking an issue as stale. Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue. 
Set to `false` to disable 19 | closeComment: false 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Pitch 18 | 19 | 20 | 21 | ### Alternatives 22 | 23 | 24 | 25 | ### Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lightning Quick Start App 2 | 3 | ### Install Lightning 4 | 5 | ```bash 6 | pip install lightning[app] 7 | ``` 8 | 9 | ### Locally 10 | 11 | In order to run the application locally, run the following commands 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | lightning run app app.py 16 | ``` 17 | 18 | ### Cloud 19 | 20 | In order to run the application in the cloud, run the following commands 21 | 22 | ### On CPU 23 | 24 | ``` 25 | lightning run app app.py --cloud 26 | ``` 27 | 28 | ### On GPU 29 | 30 | ``` 31 | USE_GPU=1 lightning run app app.py --cloud 32 | ``` 33 | 34 | ### Adding HPO support to Quick Start App. 35 | 36 | Using [Lightning HPO](https://github.com/Lightning-AI/LAI-lightning-hpo-App), you can easily convert the training component into a Sweep Component. 37 | 38 | ```bash 39 | pip install lightning-hpo 40 | lightning run app app_hpo.py 41 | ``` 42 | 43 | ### Learn how it works 44 | 45 | The components are [here](https://github.com/Lightning-AI/lightning-quick-start/blob/main/quick_start/components.py) and the code is heavily commented. 
46 | 47 | Once you understand this example well, you are no longer a beginner with Lightning Apps 🔥 48 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | # submodules: true 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.6.0 12 | hooks: 13 | - id: end-of-file-fixer 14 | - id: trailing-whitespace 15 | - id: check-case-conflict 16 | - id: check-yaml 17 | exclude: redis/redis.yml 18 | - id: check-toml 19 | - id: check-json 20 | - id: check-added-large-files 21 | - id: check-docstring-first 22 | - id: detect-private-key 23 | 24 | - repo: https://github.com/executablebooks/mdformat 25 | rev: 0.7.17 26 | hooks: 27 | - id: mdformat 28 | additional_dependencies: 29 | - mdformat-gfm 30 | - mdformat-black 31 | - mdformat_frontmatter 32 | 33 | - repo: https://github.com/astral-sh/ruff-pre-commit 34 | rev: v0.5.5 35 | hooks: 36 | - id: ruff 37 | args: ["--fix", "--line-length=120", "--target-version=py39"] 38 | - id: ruff-format 39 | args: ["--line-length=120"] 40 | -------------------------------------------------------------------------------- /.github/workflows/ci_testing.yml: -------------------------------------------------------------------------------- 1 | name: CI testing 2 | 3 | on: # Trigger the workflow on push or pull request, but only for the master branch 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 10 | cancel-in-progress: ${{ github.ref != 'refs/heads/master' }} 11 | 12 | jobs: 13 | pytest: 14 | runs-on: ubuntu-22.04 15 | timeout-minutes: 10 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 3.10 19 | uses: 
actions/setup-python@v5 20 | with: 21 | python-version: "3.10" 22 | 23 | - uses: actions/cache@v4 24 | with: 25 | path: ~/.cache/pip 26 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 27 | restore-keys: | 28 | ${{ runner.os }}-pip- 29 | 30 | - name: Install packages 31 | run: | 32 | pip install -U -r requirements.txt -r tests/requirements.txt \ 33 | -f https://download.pytorch.org/whl/cpu/torch_stable.html 34 | pip list 35 | 36 | - name: Tests 37 | run: python -m pytest tests 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ### To Reproduce 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. Go to '...' 18 | 1. Run '....' 19 | 1. Scroll down to '....' 20 | 1. 
class TrainDeploy(LightningFlow):
    """Root flow that trains a model with a script, then serves it through Gradio."""

    def __init__(self):
        """Create the training work and the serving work."""
        super().__init__()
        script_dir = ops.dirname(__file__)
        training_compute = CloudCompute("cpu-medium", idle_timeout=60)
        self.train_work = PyTorchLightningScript(
            script_path=ops.join(script_dir, "./train_script.py"),
            script_args=["--trainer.max_epochs=10"],
            cloud_compute=training_compute,
        )
        self.serve_work = ImageServeGradio()

    def run(self):
        """Run training, and start the demo once a best checkpoint is reported."""
        # Step 1: launch the python script that trains the model.
        self.train_work.run()

        # Step 2: nothing to deploy until the training work exposes a checkpoint.
        checkpoint = self.train_work.best_model_path
        if not checkpoint:
            return
        self.serve_work.run(checkpoint)

    def configure_layout(self):
        """Return the UI tabs; the training tab is omitted once training has stopped."""
        training_running = not self.train_work.has_stopped
        layout = [{"name": "Model training", "content": self.train_work}] if training_running else []
        layout.append({"name": "Interactive demo", "content": self.serve_work})
        return layout
def _load_py_module(fname, pkg=PACKAGE_NAME):
    """Execute ``<_PATH_ROOT>/<pkg>/<fname>`` as a module and return it.

    The file is loaded straight from its path instead of via ``import``.

    Args:
        fname: File name of the module inside the package directory.
        pkg: Package directory name, relative to ``_PATH_ROOT``.

    Returns:
        The executed module object.
    """
    module_name = os.path.join(pkg, fname)
    module_file = os.path.join(_PATH_ROOT, pkg, fname)
    module_spec = spec_from_file_location(module_name, module_file)
    loaded = module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
def _extract_tarfile(file_path: str, extract_path: str, mode: str) -> None:
    """Extract the tar archive at ``file_path`` into ``extract_path`` if it exists."""
    if not os.path.exists(file_path):
        return
    # Extraction filters reject members that would be written outside
    # ``extract_path``. They only exist on interpreters that ship them, hence
    # the feature check recommended by the ``tarfile`` documentation.
    filter_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(file_path, mode=mode) as tar_ref:
        for member in tar_ref.getmembers():
            try:
                tar_ref.extract(member, path=extract_path, set_attrs=False, **filter_kwargs)
            except PermissionError as err:
                raise PermissionError(f"Could not extract tar file {file_path}") from err


def _download_to(url: str, destination: str, verbose: bool, verify: bool) -> None:
    """Stream ``url`` into ``destination`` without ever leaving a partial file there."""
    chunk_size = 1024
    partial_path = destination + ".part"
    try:
        # The context manager releases the connection even if the transfer fails.
        with requests.get(url, stream=True, verify=verify) as response:
            # Fail loudly instead of saving an HTML error page as the payload.
            response.raise_for_status()
            if verbose:
                file_size = int(response.headers.get("Content-Length", 0))
                print(dict(file_size=file_size))
                print(dict(num_bars=int(file_size / chunk_size)))
            with open(partial_path, "wb") as fp:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    fp.write(chunk)
        # Publish the file only once it is complete, so an interrupted download
        # is retried next time instead of being mistaken for a cached file.
        os.replace(partial_path, destination)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_data(url: str, path: str = "data/", verbose: bool = False, verify: bool = False) -> None:
    """Download ``url`` into ``path`` and unpack it when it is an archive.

    # Code taken from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603
    # __author__ = "github.com/ruxi"
    # __license__ = "MIT"

    The file is stored as ``<path>/<last URL segment>``. When that file already
    exists no network request is made and only the extraction step runs.
    ``.zip``, ``.tar.gz``/``.tgz`` and ``.tar.bz2``/``.tbz`` files are extracted
    into ``path``; any other file is left as downloaded.

    Args:
        url: Address of the file to fetch.
        path: Target directory, created when missing.
        verbose: Print the reported file size and chunk count while downloading.
        verify: Whether to verify the server's TLS certificate.

    Raises:
        NotImplementedError: If ``url`` is the ``"NEED_TO_BE_CREATED"`` placeholder.
        requests.HTTPError: If the server answers with an error status.
        PermissionError: If a tar member cannot be written.

    Usage:
        download_data('http://web4host.net/5MB.zip')
    """
    if url == "NEED_TO_BE_CREATED":
        raise NotImplementedError

    os.makedirs(path, exist_ok=True)
    file_name = url.split("/")[-1]
    local_filename = os.path.join(path, file_name)

    if not os.path.exists(local_filename):
        # NOTE(review): ``verify`` defaults to False only to keep the historical
        # behaviour; that disables TLS certificate checks. Pass ``verify=True``
        # unless the host is known to use a certificate that cannot be validated.
        _download_to(url, local_filename, verbose, verify)

    # Decide from the file name alone: testing the joined path made a directory
    # such as ``cache.zip.d/`` look like a zip archive.
    if ".zip" in file_name:
        with zipfile.ZipFile(local_filename, "r") as zip_ref:
            zip_ref.extractall(path)
    elif file_name.endswith((".tar.gz", ".tgz")):
        _extract_tarfile(local_filename, path, "r:gz")
    elif file_name.endswith((".tar.bz2", ".tbz")):
        _extract_tarfile(local_filename, path, "r:bz2")
5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | docs/source/api/ 60 | docs/source/*.md 61 | 62 | # PyBuilder 63 | target/ 64 | 65 | # Jupyter Notebook 66 | .ipynb_checkpoints 67 | 68 | # IPython 69 | profile_default/ 70 | ipython_config.py 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # PEP 582; used by e.g. 
class Net(nn.Module):
    """Two-convolution, two-linear-layer classifier producing 10 log-probabilities.

    ``fc1`` expects 9216 features per sample, i.e. 64 channels x 12 x 12, which
    is what the convolution/pooling stack produces for a 1 x 28 x 28 input.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint format
        # and of seeded initialisation; keep them stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities (``log_softmax`` over dim 1) for ``x``."""
        # Convolutional feature extractor followed by 2x2 max pooling.
        feature_maps = F.relu(self.conv1(x))
        feature_maps = F.relu(self.conv2(feature_maps))
        pooled = self.dropout1(F.max_pool2d(feature_maps, 2))

        # Classifier head on the flattened feature maps.
        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
import os
import re
from typing import List

_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))


def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]:
    """Load requirements from a file.

    Comments (everything from ``comment_char`` to the end of the line) are removed,
    and lines that are empty, start with ``http`` (directly installed dependencies)
    or start with ``--extra-index-url`` are skipped.

    Args:
        path_dir: Directory containing the requirements file.
        file_name: Name of the requirements file inside ``path_dir``.
        comment_char: Marker that starts a comment; an empty string disables
            comment stripping.

    Returns:
        The remaining requirement specifiers, in file order.

    >>> _load_requirements(_PROJECT_ROOT)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    [...]
    """
    reqs = []
    # Explicit encoding so the result does not depend on the platform locale.
    with open(os.path.join(path_dir, file_name), encoding="utf-8") as file:
        for raw_line in file:
            ln = raw_line.strip()
            # filter out all comments
            if comment_char and comment_char in ln:
                ln = ln[: ln.index(comment_char)].strip()
            # skip directly installed dependencies and index urls
            if ln.startswith(("http", "--extra-index-url")):
                continue
            if ln:  # if requirement is not empty
                reqs.append(ln)
    return reqs


def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
    """Load the README as the package description, pinned to a given release.

    Relative static-asset paths and "stable"/"master" badge links are rewritten to
    point at ``ver``, and sections wrapped in the PyPI skip markers are removed.

    Args:
        path_dir: Directory containing ``README.md``.
        homepage: Base URL of the repository, used to build absolute asset links.
        ver: Release/tag name substituted into the links.

    Returns:
        The adapted README text.
    """
    path_readme = os.path.join(path_dir, "README.md")
    with open(path_readme, encoding="utf-8") as fp:
        text = fp.read()

    # e.g. https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png
    # URLs always use "/"; `os.path.join` would insert backslashes on Windows.
    github_source_url = "/".join([homepage.rstrip("/"), "raw", ver])
    # replace relative repository path with an absolute link to the release
    # do not replace all "docs" as in the readme we refer to some other sources with a particular path to docs
    text = text.replace("docs/source/_static/", f"{github_source_url}/docs/source/_static/")

    # readthedocs badge
    text = text.replace("badge/?version=stable", f"badge/?version={ver}")
    # keep the trailing slash so that deeper links (".../en/stable/page.html") stay valid
    text = text.replace("lightning.readthedocs.io/en/stable/", f"lightning.readthedocs.io/en/{ver}/")
    # codecov badge
    text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg")
    # replace github badges for release ones
    text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}")

    # Markers must be non-empty: with empty markers the pattern degenerates to ".+?"
    # and the substitution would erase the whole README.
    skip_begin = re.escape("<!-- following section will be skipped from PyPI description -->")
    skip_end = re.escape("<!-- end skipping PyPI description -->")
    # todo: wrap content as commented description
    text = re.sub(
        rf"{skip_begin}.+?{skip_end}",
        "",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    return text
def run(self, *args, **kwargs):
    """Prepare the demo environment and execute the training script.

    Method of ``PyTorchLightningScript``. It downloads pre-trained demo weights into
    the current directory, adds Trainer flags that shorten training and enable
    checkpointing on ``val_loss``, then delegates to the parent ``run``.

    Args:
        *args: Forwarded unchanged to the parent ``run``.
        **kwargs: Forwarded unchanged to the parent ``run``.
    """
    ######### [DEMO PURPOSE] #########

    # 1. Download a pre-trained model for speed reason.
    download_data(
        "https://pl-flash-data.s3.amazonaws.com/assets_lightning/demo_weights.pt",
        "./",
    )

    # 2. Add some arguments to the Trainer to make training faster.
    demo_flags = [
        "--trainer.limit_train_batches=12",
        "--trainer.limit_val_batches=4",
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=val_loss",
    ]
    # Only add flags that are not present yet, so calling `run` again does not
    # keep growing `script_args` with duplicates.
    self.script_args += [flag for flag in demo_flags if flag not in self.script_args]

    # 3. Utilities
    warnings.simplefilter("ignore")
    ######### [DEMO PURPOSE] #########

    logger.info(f"Running train_script: {self.script_path}")

    # 4. Execute the parent run method
    super().run(*args, **kwargs)
class ImageServeGradio(ServeGradio):
    """Gradio component serving digit predictions from a TorchScript model.

    The model file is the one referenced by ``best_model_path``, which is passed to
    ``run``; predictions are returned as a mapping from the labels ``"0"``..``"9"``
    to scores.
    """

    inputs = gradio.inputs.Image(type="pil", shape=(28, 28))
    outputs = gradio.outputs.Label(num_top_classes=10)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.examples = None
        self.best_model_path = None
        self._transform = None
        self._labels = {idx: str(idx) for idx in range(10)}

    def run(self, best_model_path):
        """Download example images, remember the model path and start serving.

        Args:
            best_model_path: Location of the TorchScript model loaded by ``build_model``.
        """
        ######### [DEMO PURPOSE] #########
        # Download some examples so it works locally and in the cloud (issue with gradio on loading the images.)
        download_data(
            "https://pl-flash-data.s3.amazonaws.com/assets_lightning/images.tar.gz",
            "./",
        )
        # `os.listdir` returns entries in arbitrary order; sort for a stable example list.
        self.examples = [os.path.join("./images", f) for f in sorted(os.listdir("./images"))]
        ######### [DEMO PURPOSE] #########

        self.best_model_path = best_model_path
        # NOTE(review): no normalization is applied here, unlike a typical MNIST
        # training transform — confirm this matches how the model was trained.
        self._transform = T.Compose([T.Resize((28, 28)), T.ToTensor()])
        super().run()

    def predict(self, img):
        """Return a ``{label: score}`` mapping for the ten digit classes of ``img``."""
        with torch.inference_mode():
            # 1. Receive an image and transform it into a tensor
            #    (keep only the first channel, then add batch and channel dimensions).
            img = self._transform(img)[0]
            img = img.unsqueeze(0).unsqueeze(0)

            # 2. Apply the model on the image and exponentiate its output
            #    (the output is treated as log-probabilities).
            prediction = torch.exp(self.model(img))

        # 3. Return the data in the `gr.outputs.Label` format
        return {self._labels[i]: prediction[0][i].item() for i in range(10)}

    def build_model(self):
        """Load the TorchScript model from ``best_model_path`` and freeze it for inference."""
        # 1. Load the best model. The file is a TorchScript archive, so it must be read
        #    with `torch.jit.load`; `torch.load` only handled such archives through a
        #    deprecated fallback and rejects them when `weights_only=True`.
        model = torch.jit.load(str(self.best_model_path))

        # 2. Prepare the model for predictions.
        for p in model.parameters():
            p.requires_grad = False
        model.eval()

        # 3. Return the model.
        return model