├── .envrc ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── configs ├── mnist.yaml ├── mrpc.yaml ├── presets │ ├── ddp.yaml │ ├── debugger.yaml │ ├── default.yaml │ ├── limiter.yaml │ ├── overfitter.yaml │ ├── profiler.yaml │ └── tester.yaml └── sweep_mnist.yaml ├── data └── .gitkeep ├── environment.yaml ├── models └── .gitkeep ├── notebooks └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── results └── .gitkeep ├── run ├── scripts ├── .gitkeep ├── print_results ├── run.sh ├── sweep └── sweep_mnist.sh └── src ├── __init__.py ├── callbacks ├── __init__.py └── metric.py ├── datamodules ├── __init__.py ├── datasets │ └── __init__.py ├── glue_datamodule.py └── mnist_datamodule.py ├── models ├── __init__.py ├── glue_transformer.py ├── mnist_model.py └── modules │ └── __init__.py ├── utils ├── __init__.py ├── lit_cli.py ├── loggers.py └── sweep_cli.py └── vendor └── __init__.py /.envrc: -------------------------------------------------------------------------------- 1 | layout conda lit-template 2 | export PATH=$PWD:$PWD/scripts:$PATH 3 | export PYTHONPATH=$PWD${PYTHONPATH:+":$PYTHONPATH"} 4 | [ -f .env ] && dotenv 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python ### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | ### Lightning-Template ### 133 | /data/* 134 | !/data/.gitkeep 135 | /models/* 136 | !/models/.gitkeep 137 | /results/* 138 | !/results/.gitkeep 139 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-executables-have-shebangs 8 | - id: check-json 9 | - id: check-shebang-scripts-are-executable 10 | - id: check-yaml 11 | - id: detect-private-key 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | 15 | - repo: https://github.com/charliermarsh/ruff-pre-commit 16 | rev: v0.11.8 17 | hooks: 18 | - id: ruff 19 | - id: ruff-format 20 | 21 | - repo: https://github.com/kynan/nbstripout.git 22 | rev: 0.8.1 23 | hooks: 24 | - id: nbstripout 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tianshu Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lightning-Template 2 | 3 | [![python](https://img.shields.io/badge/-Python_3.10_%7C_3.11_%7C_3.12-blue?logo=python&logoColor=white&style=flat-square)](https://github.com/tshu-w/lightning-template) 4 | [![pytorch](https://img.shields.io/badge/PyTorch_2.4+-ee4c2c?logo=pytorch&logoColor=white&style=flat-square)](https://pytorch.org) 5 | [![lightning](https://img.shields.io/badge/Lightning_2.4+-792ee5?logo=pytorchlightning&logoColor=white&style=flat-square)](https://lightning.ai) 6 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json&style=flat-square)](https://github.com/astral-sh/ruff) 7 | [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray&style=flat-square)](https://github.com/tshu-w/lightning-template?tab=MIT-1-ov-file) 8 | 9 | A clean and flexible [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) template to kickstart and structure your deep learning project, ensuring an efficient workflow, reproducibility, and easy extensibility for rapid experiments. 10 | 11 | ### Why Lightning-Template? 12 | 13 | PyTorch Lightning is a deep learning framework designed for professional AI researchers and engineers, freeing users from boilerplate code (_e.g._, multi-GPU/TPU/HPU training, early stopping, and checkpointing) to focus on going from idea to paper/production. 14 | 15 | This Lightning template leverages [Lightning CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) to separate configuration from source code, guaranteeing reproducibility of experiments, and incorporates many other [best practices](#best-practices). 16 | 17 | + **Compared to [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template)**: Our template provides similar functionality through a simple and straightforward encapsulation of Lightning's built-in CLI, making it suitable for users who prefer a minimal setup without an additional Hydra layer. 18 | 19 | > Note: This is an unofficial project that lacks comprehensive tests and continuous integration. 20 | 21 | ### Quickstart 22 | 23 | ```console 24 | git clone https://github.com/YourGithubName/your-repository-name 25 | cd your-repository-name 26 | 27 | # [SUGGESTED] use conda environment 28 | conda env create -n env-name -f environment.yaml 29 | conda activate env-name 30 | 31 | # [ALTERNATIVE] install requirements directly 32 | pip install -r requirements.txt 33 | 34 | # Run the sample script, i.e., ./run fit --config configs/mnist.yaml 35 | bash -x scripts/run.sh 36 | ``` 37 | 38 | ### Workflow - how it works 39 | 40 | Before using this template, please read the basic PyTorch Lightning documentation: [Lightning in 15 minutes](https://lightning.ai/docs/pytorch/stable/starter/introduction.html). A minimal end-to-end sketch follows the list below. 41 | 42 | 1. Define a [Lightning Module](https://lightning.ai/docs/pytorch/2.4.0/common/lightning_module.html) (Examples: [mnist_model.py](src/models/mnist_model.py) and [glue_transformer.py](src/models/glue_transformer.py)) 43 | 2. Define a [Lightning DataModule](https://lightning.ai/docs/pytorch/2.4.0/data/datamodule.html#lightningdatamodule) (Examples: [mnist_datamodule.py](src/datamodules/mnist_datamodule.py) and [glue_datamodule.py](src/datamodules/glue_datamodule.py)) 44 | 3. Prepare your experiment configs (Examples: [mnist.yaml](configs/mnist.yaml) and [mrpc.yaml](configs/mrpc.yaml)) 45 | 4. Run experiments (_cf._, [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)) 46 | - To see the available commands type: 47 | ```console 48 | ./run --help 49 | ``` 50 | - Train a model from the config: 51 | ```console 52 | ./run fit --config configs/mnist.yaml 53 | ``` 54 | - Override config options: 55 | ```console 56 | ./run fit --config configs/mnist.yaml --trainer.precision 16 --model.learning_rate 0.1 --data.batch_size 64 57 | ``` 58 | - Separate model and datamodule configs: 59 | ```console 60 | ./run fit --config configs/data.yaml --config configs/model.yaml 61 | ``` 62 | 
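Concretely, steps 1–3 rarely amount to more than a module, a datamodule, and a config. A minimal sketch (the `ToyModel`/`ToyDataModule` names and the synthetic regression data are illustrative only, not part of this repository):

```python
import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(L.LightningModule):  # hypothetical example, not in src/
    def __init__(self, hidden_dim: int = 32, learning_rate: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()  # exposes init args as --model.* CLI flags
        self.net = torch.nn.Sequential(
            torch.nn.Linear(8, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


class ToyDataModule(L.LightningDataModule):  # hypothetical example, not in src/
    def __init__(self, batch_size: int = 16):
        super().__init__()
        self.save_hyperparameters()  # exposes init args as --data.* CLI flags

    def setup(self, stage: str | None = None):
        # synthetic regression data: predict the sum of 8 random features
        x = torch.randn(256, 8)
        self.dataset = TensorDataset(x, x.sum(dim=1, keepdim=True))

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.hparams.batch_size, shuffle=True)
```

Once such classes are importable from `src`, a config that sets `model.class_path: ToyModel` and `data.class_path: ToyDataModule` is everything `./run fit --config ...` needs.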
63 | ### Project Structure 64 | The directory structure of a project looks like this: 65 | ``` 66 | lightning-template 67 | ├── configs ← Directory of Configs 68 | │   ├── mnist.yaml 69 | │   ├── mrpc.yaml 70 | │   ├── presets ← Preset configs for Lightning features 71 | │   └── sweep_mnist.yaml 72 | ├── data ← Directory of Data 73 | ├── environment.yaml 74 | ├── models ← Directory of Models 75 | ├── notebooks ← Directory of Notebooks 76 | ├── pyproject.toml 77 | ├── README.md 78 | ├── requirements.txt 79 | ├── results ← Directory of Results 80 | ├── run ← Script to Run Lightning CLI 81 | ├── scripts ← Directory of Scripts 82 | │   ├── print_results 83 | │   ├── run.sh 84 | │   ├── sweep ← Script to Sweep Experiments 85 | │   └── sweep_mnist.sh 86 | └── src ← Directory of Source Code 87 | ├── callbacks 88 | ├── datamodules 89 | ├── models 90 | ├── utils 91 | └── vendor ← Directory of Third-Party Code 92 | ``` 93 | 94 | ### Best Practices 95 | 1. Use [conda](https://docs.anaconda.com/miniconda/) to manage environments. 96 | 2. Leverage Lightning's awesome features (_cf._, [How-to Guides](https://lightning.ai/docs/pytorch/stable/common/) & [Glossary](https://lightning.ai/docs/pytorch/stable/glossary/)). 97 | 3. Use [pre-commit](https://pre-commit.com) and [ruff](https://docs.astral.sh/ruff) to check and format code, with configuration in [pyproject.toml](pyproject.toml) and [.pre-commit-config.yaml](.pre-commit-config.yaml). 98 | ```console 99 | pre-commit install 100 | ``` 101 | 4. Use [direnv](https://direnv.net) to automatically switch environments and set variables (_cf._, [.envrc](.envrc)). 102 | ```console 103 | λ cd lightning-template 104 | direnv: loading ~/lightning-template/.envrc 105 | direnv: export +CONDA_DEFAULT_ENV +CONDA_EXE +CONDA_PREFIX +CONDA_PROMPT_MODIFIER +CONDA_PYTHON_EXE +CONDA_SHLVL +_CE_CONDA +_CE_M ~PATH ~PYTHONPATH 106 | ``` 107 | 1. Add the project root to `PATH` to use the `run` script directly. 108 | ```console 109 | export PATH=$PWD:$PWD/scripts:$PATH 110 | run fit --config configs/mnist.yaml 111 | ``` 112 | 2. Add the project root to `PYTHONPATH` to avoid modifying `sys.path` in scripts. 113 | ```console 114 | export PYTHONPATH=$PWD${PYTHONPATH:+":$PYTHONPATH"} 115 | ``` 116 | 3. Save private variables in `.env`. 117 | 5. Use [shtab](https://jsonargparse.readthedocs.io/en/stable/#tab-completion) to generate a shell completion file. 118 | 119 | 6. Use [ray tune](https://docs.ray.io/en/latest/tune/index.html) to sweep parameters or run hyperparameter searches (_cf._, [sweep_cli.py](src/utils/sweep_cli.py)). 120 | ```console 121 | bash ./scripts/sweep --config configs/sweep_mnist.yaml 122 | ``` 123 | 7. Use a third-party logger (_e.g._, [W&B](https://wandb.ai) and [Aim](https://aimstack.io)) to track experiments, as sketched below. 124 | 
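For example, the `CSVLogger` wired up in [default.yaml](configs/presets/default.yaml) can be swapped for W&B with a small preset. A hedged sketch (the `configs/presets/wandb.yaml` path and the `project` value are hypothetical, and `wandb` must be installed):

```yaml
# configs/presets/wandb.yaml (hypothetical preset, not shipped with the template)
trainer:
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: your-project-name
      save_dir: results
```

Stack it onto any run the same way the other presets are used: `./run fit --config configs/mnist.yaml --config configs/presets/wandb.yaml`.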
125 | ### DELETE EVERYTHING ABOVE FOR YOUR PROJECT 126 | 127 | --- 128 | 
129 | <div align="center"> 130 | 131 | # Your Project Name 132 | 133 | <a href="#">Arxiv</a> <a href="#">Conference</a> 134 | 135 | </div> 136 | 137 | 
138 | 139 | ## Description 140 | What it does 141 | 142 | ## How to run 143 | First, install the dependencies: 144 | ```console 145 | # clone project 146 | git clone https://github.com/YourGithubName/your-repository-name 147 | cd your-repository-name 148 | 149 | # [SUGGESTED] use conda environment 150 | conda env create -f environment.yaml 151 | conda activate lit-template 152 | 153 | # [ALTERNATIVE] install requirements directly 154 | pip install -r requirements.txt 155 | ``` 156 | 157 | Next, to obtain the main results of the paper: 158 | ```console 159 | # commands to get the main results 160 | ``` 161 | 162 | You can also run experiments with the `run` script. 163 | ```console 164 | # fit with the demo config 165 | ./run fit --config configs/demo.yaml 166 | # or specific command line arguments 167 | ./run fit --model MNISTModel --data MNISTDataModule --data.batch_size 32 --trainer.accelerator cpu 168 | 169 | # evaluate with the checkpoint 170 | ./run test --config configs/demo.yaml --ckpt_path ckpt_path 171 | 172 | # get the script help 173 | ./run --help 174 | ./run fit --help 175 | ``` 176 | 177 | ## Citation 178 | ``` 179 | @article{YourName, 180 | title={Your Title}, 181 | author={Your team}, 182 | journal={Location}, 183 | year={Year} 184 | } 185 | ``` 186 | -------------------------------------------------------------------------------- /configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 123 2 | trainer: 3 | max_epochs: 20 4 | model: 5 | class_path: MNISTModel 6 | data: 7 | class_path: MNISTDataModule 8 | -------------------------------------------------------------------------------- /configs/mrpc.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 123 2 | trainer: 3 | max_epochs: 30 4 | model: 5 | class_path: GLUETransformer 6 | init_args: 7 | model_name_or_path: bert-base-uncased 8 | max_length: 256 9 | data: 10 | class_path: GLUEDataModule 11 | init_args: 12 | task_name: mrpc 13 | batch_size: 32 14 | num_workers: 0 15 | pin_memory: true 16 | -------------------------------------------------------------------------------- /configs/presets/ddp.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | devices: 4 3 | # strategy: ddp 4 | strategy: ddp_find_unused_parameters_false 5 | -------------------------------------------------------------------------------- /configs/presets/debugger.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: "cpu" 3 | detect_anomaly: true 4 | fast_dev_run: 1 5 | -------------------------------------------------------------------------------- /configs/presets/default.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | default_root_dir: results 4 | callbacks: 5 | class_path: Metric 6 | logger: 7 | class_path: CSVLogger 8 | init_args: 9 | save_dir: results 10 | accelerator: auto 11 | devices: 1 12 | -------------------------------------------------------------------------------- /configs/presets/limiter.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | max_epochs: 5 3 | limit_train_batches: 10 4 | limit_val_batches: 5 5 | limit_test_batches: 5 6 | logger: false 7 | -------------------------------------------------------------------------------- /configs/presets/overfitter.yaml: 
-------------------------------------------------------------------------------- 1 | trainer: 2 | overfit_batches: 0.01 3 | logger: false 4 | -------------------------------------------------------------------------------- /configs/presets/profiler.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | max_epochs: 1 3 | profiler: "simple" 4 | # profiler: "advanced" 5 | # profiler: "pytorch" 6 | -------------------------------------------------------------------------------- /configs/presets/tester.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | fast_dev_run: 5 3 | logger: false 4 | -------------------------------------------------------------------------------- /configs/sweep_mnist.yaml: -------------------------------------------------------------------------------- 1 | fit: 2 | debug: false 3 | gpus_per_trial: 1 4 | configs: 5 | - configs/mnist.yaml 6 | override_kwargs: 7 | seed_everything: 8 | - 123 9 | - 42 10 | trainer.max_epochs: 5 11 | data.batch_size: 12 | - 64 13 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/data/.gitkeep -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: lit-template 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: 7 | - python=3.12 8 | - pip 9 | - pip: 10 | - --requirement requirements.txt 11 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/models/.gitkeep -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/notebooks/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # https://github.com/microsoft/pyright 2 | [tool.pyright] 3 | include = ["src"] 4 | venv = "lit-template" 5 | typeCheckingMode = "off" 6 | useLibraryCodeForTypes = true 7 | 8 | # https://github.com/charliermarsh/ruff 9 | [tool.ruff] 10 | fix = true 11 | line-length = 88 12 | target-version = "py310" 13 | [tool.ruff.lint] 14 | select = [ 15 | "E", # pycodestyle 16 | "F", # Pyflakes 17 | "UP", # pyupgrade 18 | "B", # flake8-bugbear 19 | "SIM", # flake8-simplify 20 | "I", # isort 21 | ] 22 | ignore = ["E501"] 23 | # https://github.com/timothycrosley/isort/ 24 | [tool.ruff.lint.isort] 25 | combine-as-imports = true 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 2.4.0 2 | lightning >= 2.4.0 3 | torchvision 4 | jsonargparse[signatures] # for CLI 5 | ray[tune] 6 | 7 | transformers 8 | scikit-learn 9 | datasets 10 | evaluate 11 | 12 | # dev tools 13 | jupyterlab 
14 | shtab 15 | -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/results/.gitkeep -------------------------------------------------------------------------------- /run: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "$(dirname "$0")" 3 | 4 | python -m src.utils.lit_cli "$@" 5 | -------------------------------------------------------------------------------- /scripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/scripts/.gitkeep -------------------------------------------------------------------------------- /scripts/print_results: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Modified from: https://github.com/allenai/allennlp/blob/main/allennlp/commands/print_results.py 3 | 4 | import argparse 5 | import json 6 | from pathlib import Path 7 | from signal import SIG_DFL, SIGPIPE, signal 8 | 9 | signal(SIGPIPE, SIG_DFL) 10 | 11 | 12 | def main(args: argparse.Namespace): 13 | """ 14 | Prints results from an `argparse.Namespace` object. 15 | """ 16 | path = args.path 17 | metrics_name = args.metrics_filename 18 | keys = args.keys 19 | 20 | results_dict = {} 21 | for f in path.rglob(metrics_name): 22 | with open(f) as file_: 23 | metrics = json.load(file_) 24 | name = f.parents[0].relative_to(f.parents[2]) 25 | results_dict[name] = metrics 26 | 27 | sorted_keys = sorted(list(results_dict.keys())) 28 | print(f"{path.name}, {', '.join(keys)}") 29 | for name in sorted_keys: 30 | results = results_dict[name] 31 | keys_to_print = (str(results.get(key, "N/A")) for key in keys) 32 | print(f"{name}, {', '.join(keys_to_print)}") 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser( 37 | description="Print experiment results in a helpful CSV format." 38 | ) 39 | 40 | parser.add_argument( 41 | "path", 42 | type=Path, 43 | help="Path to recursively search for experiment directories.", 44 | ) 45 | parser.add_argument( 46 | "-k", 47 | "--keys", 48 | type=str, 49 | nargs="+", 50 | help='Keys to print from metrics.json. 
Keys not present in all metrics.json will result in "N/A"', 51 | default=[], 52 | required=False, 53 | ) 54 | parser.add_argument( 55 | "-m", 56 | "--metrics-filename", 57 | type=str, 58 | help="Name of the metrics file to inspect.", 59 | default="metrics.json", 60 | required=False, 61 | ) 62 | 63 | args = parser.parse_args() 64 | main(args) 65 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | ./run fit --config configs/mnist.yaml 2 | -------------------------------------------------------------------------------- /scripts/sweep: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "$(dirname $(dirname "$0"))" 3 | 4 | python -m src.utils.sweep_cli "$@" 5 | -------------------------------------------------------------------------------- /scripts/sweep_mnist.sh: -------------------------------------------------------------------------------- 1 | bash ./scripts/sweep --config configs/sweep_mnist.yaml 2 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from . import callbacks, datamodules, models 2 | from .utils import loggers 3 | 4 | __all__ = ["callbacks", "datamodules", "models", "loggers"] 5 | -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import Metric 2 | 3 | __all__ = ["Metric"] 4 | -------------------------------------------------------------------------------- /src/callbacks/metric.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from pathlib import Path 4 | 5 | import lightning as L 6 | from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars 7 | from lightning.pytorch.trainer.states import TrainerFn 8 | 9 | 10 | class Metric(L.Callback): 11 | r""" 12 | Save logged metrics to ``Trainer.log_dir``. 
13 | """ 14 | 15 | def teardown( 16 | self, 17 | trainer: L.Trainer, 18 | pl_module: L.LightningModule, 19 | stage: str | None = None, 20 | ) -> None: 21 | metrics = {} 22 | if stage == TrainerFn.FITTING: 23 | if ( 24 | trainer.checkpoint_callback 25 | and trainer.checkpoint_callback.best_model_path 26 | ): 27 | ckpt_path = trainer.checkpoint_callback.best_model_path 28 | # inhibit disturbing logging 29 | logging.getLogger("lightning.pytorch.utilities.distributed").setLevel( 30 | logging.WARNING 31 | ) 32 | logging.getLogger("lightning.pytorch.accelerators.gpu").setLevel( 33 | logging.WARNING 34 | ) 35 | 36 | fn_kwargs = { 37 | "model": pl_module, 38 | "datamodule": trainer.datamodule, 39 | "ckpt_path": ckpt_path, 40 | } 41 | 42 | val_metrics = {} 43 | if trainer.validate_loop._data_source.is_defined(): 44 | trainer.validate(**fn_kwargs) 45 | val_metrics = convert_tensors_to_scalars(trainer.logged_metrics) 46 | 47 | test_metrics = {} 48 | if trainer.test_loop._data_source.is_defined(): 49 | trainer.test(**fn_kwargs) 50 | test_metrics = convert_tensors_to_scalars(trainer.logged_metrics) 51 | 52 | metrics = {**val_metrics, **test_metrics} 53 | else: 54 | metrics = convert_tensors_to_scalars(trainer.logged_metrics) 55 | 56 | if metrics: 57 | metrics_str = json.dumps(metrics, ensure_ascii=False, indent=2) 58 | 59 | metrics_file = Path(trainer.log_dir) / "metrics.json" 60 | with metrics_file.open("w") as f: 61 | f.write(metrics_str) 62 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .glue_datamodule import GLUEDataModule 2 | from .mnist_datamodule import MNISTDataModule 3 | 4 | __all__ = ["GLUEDataModule", "MNISTDataModule"] 5 | -------------------------------------------------------------------------------- /src/datamodules/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/datamodules/datasets/__init__.py -------------------------------------------------------------------------------- /src/datamodules/glue_datamodule.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import partial 3 | from typing import Literal 4 | 5 | import lightning as L 6 | from datasets import load_dataset 7 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS 8 | from torch.utils.data import DataLoader 9 | 10 | warnings.filterwarnings( 11 | "ignore", ".*Consider increasing the value of the `num_workers` argument*" 12 | ) 13 | 14 | TASK_NAME = Literal[ 15 | "cola", 16 | "sst2", 17 | "mrpc", 18 | "qqp", 19 | "stsb", 20 | "mnli", 21 | "qnli", 22 | "rte", 23 | "wnli", 24 | "ax", 25 | ] 26 | 27 | 28 | class GLUEDataModule(L.LightningDataModule): 29 | task_text_field_map = { 30 | "cola": ["sentence"], 31 | "sst2": ["sentence"], 32 | "mrpc": ["sentence1", "sentence2"], 33 | "qqp": ["question1", "question2"], 34 | "stsb": ["sentence1", "sentence2"], 35 | "mnli": ["premise", "hypothesis"], 36 | "qnli": ["question", "sentence"], 37 | "rte": ["sentence1", "sentence2"], 38 | "wnli": ["sentence1", "sentence2"], 39 | "ax": ["premise", "hypothesis"], 40 | } 41 | 42 | glue_task_num_labels = { 43 | "cola": 2, 44 | "sst2": 2, 45 | "mrpc": 2, 46 | "qqp": 2, 47 | "stsb": 1, 48 | "mnli": 3, 49 | "qnli": 2, 50 | "rte": 2, 51 | 
"wnli": 2, 52 | "ax": 3, 53 | } 54 | 55 | def __init__( 56 | self, 57 | task_name: TASK_NAME = "mrpc", 58 | batch_size: int = 32, 59 | num_workers: int = 0, 60 | pin_memory: bool = False, 61 | ): 62 | super().__init__() 63 | self.save_hyperparameters() 64 | 65 | self.task_name = task_name 66 | self.num_labels = self.glue_task_num_labels[task_name] 67 | self.text_fields = self.task_text_field_map[task_name] 68 | 69 | def prepare_data(self) -> None: 70 | # setup first to prevent datasets cache conflicts in multiple processes. 71 | self.setup() 72 | 73 | def setup(self, stage: str | None = None) -> None: 74 | if not hasattr(self, "datasets"): 75 | convert_to_features = self.trainer.model.convert_to_features 76 | preprocess_fn = partial(self._preprocess, text_fields=self.text_fields) 77 | 78 | def preprocess(x): 79 | return convert_to_features(preprocess_fn(x)) 80 | 81 | datasets = load_dataset("glue", self.task_name) 82 | columns_names = self.text_fields + ["label", "idx"] 83 | self.datasets = datasets.map( 84 | preprocess, 85 | batched=True, 86 | remove_columns=columns_names, 87 | ) 88 | 89 | self.datasets.set_format(type="torch") 90 | 91 | self.val_splits = [x for x in self.datasets if "validation" in x] 92 | self.test_splits = [x for x in self.datasets if "test" in x] 93 | 94 | self.collate_fn = getattr(self.trainer.model, "collate_fn", None) 95 | 96 | def train_dataloader(self) -> TRAIN_DATALOADERS: 97 | return DataLoader( 98 | dataset=self.datasets["train"], 99 | batch_size=self.hparams.batch_size, 100 | num_workers=self.hparams.num_workers, 101 | pin_memory=self.hparams.pin_memory, 102 | collate_fn=self.collate_fn, 103 | persistent_workers=self.hparams.num_workers > 0, 104 | shuffle=True, 105 | ) 106 | 107 | def val_dataloader(self) -> EVAL_DATALOADERS: 108 | val_dataloaders = [ 109 | DataLoader( 110 | dataset=self.datasets[x], 111 | batch_size=self.hparams.batch_size, 112 | num_workers=self.hparams.num_workers, 113 | pin_memory=self.hparams.pin_memory, 114 | collate_fn=self.collate_fn, 115 | persistent_workers=self.hparams.num_workers > 0, 116 | shuffle=False, 117 | ) 118 | for x in self.val_splits 119 | ] 120 | 121 | return val_dataloaders[0] if len(val_dataloaders) == 1 else val_dataloaders 122 | 123 | def test_dataloader(self) -> EVAL_DATALOADERS: 124 | test_dataloaders = [ 125 | DataLoader( 126 | dataset=self.datasets[x], 127 | batch_size=self.hparams.batch_size, 128 | num_workers=self.hparams.num_workers, 129 | pin_memory=self.hparams.pin_memory, 130 | collate_fn=self.collate_fn, 131 | persistent_workers=self.hparams.num_workers > 0, 132 | shuffle=False, 133 | ) 134 | for x in self.test_splits 135 | ] 136 | 137 | return test_dataloaders[0] if len(test_dataloaders) == 1 else test_dataloaders 138 | 139 | @staticmethod 140 | def _preprocess(batch, text_fields): 141 | if len(text_fields) > 1: 142 | text = list(zip(batch[text_fields[0]], batch[text_fields[1]], strict=True)) 143 | else: 144 | text = batch[text_fields[0]] 145 | labels = batch["label"] 146 | 147 | return {"text": text, "labels": labels} 148 | -------------------------------------------------------------------------------- /src/datamodules/mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS 3 | from torch.utils.data import DataLoader, random_split 4 | from torchvision.datasets import MNIST 5 | from torchvision.transforms import transforms 6 | 7 | 8 | class 
MNISTDataModule(L.LightningDataModule): 9 | def __init__( 10 | self, 11 | data_dir: str = "data/", 12 | batch_size: int = 64, 13 | num_workers: int = 0, 14 | pin_memory: bool = False, 15 | ): 16 | super().__init__() 17 | self.save_hyperparameters() 18 | 19 | self.transforms = transforms.ToTensor() 20 | self.data = {} 21 | 22 | def prepare_data(self) -> None: 23 | MNIST(self.hparams.data_dir, train=True, download=True) 24 | MNIST(self.hparams.data_dir, train=False, download=True) 25 | 26 | def setup(self, stage: str | None = None) -> None: 27 | if not self.data: 28 | dataset = MNIST( 29 | self.hparams.data_dir, train=True, transform=self.transforms 30 | ) 31 | self.data["train"], self.data["val"] = random_split(dataset, [55000, 5000]) 32 | 33 | self.data["test"] = MNIST( 34 | self.hparams.data_dir, train=False, transform=self.transforms 35 | ) 36 | 37 | def train_dataloader(self) -> TRAIN_DATALOADERS: 38 | return DataLoader( 39 | dataset=self.data["train"], 40 | batch_size=self.hparams.batch_size, 41 | num_workers=self.hparams.num_workers, 42 | pin_memory=self.hparams.pin_memory, 43 | persistent_workers=self.hparams.num_workers > 0, 44 | shuffle=True, 45 | ) 46 | 47 | def val_dataloader(self) -> EVAL_DATALOADERS: 48 | return DataLoader( 49 | dataset=self.data["val"], 50 | batch_size=self.hparams.batch_size, 51 | num_workers=self.hparams.num_workers, 52 | pin_memory=self.hparams.pin_memory, 53 | persistent_workers=self.hparams.num_workers > 0, 54 | shuffle=False, 55 | ) 56 | 57 | def test_dataloader(self) -> EVAL_DATALOADERS: 58 | return DataLoader( 59 | dataset=self.data["test"], 60 | batch_size=self.hparams.batch_size, 61 | num_workers=self.hparams.num_workers, 62 | pin_memory=self.hparams.pin_memory, 63 | persistent_workers=self.hparams.num_workers > 0, 64 | shuffle=False, 65 | ) 66 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .glue_transformer import GLUETransformer 2 | from .mnist_model import MNISTModel 3 | 4 | __all__ = ["GLUETransformer", "MNISTModel"] 5 | -------------------------------------------------------------------------------- /src/models/glue_transformer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any 3 | 4 | import evaluate 5 | import lightning as L 6 | import torch 7 | from lightning.pytorch.utilities.types import STEP_OUTPUT 8 | from transformers import ( 9 | AutoModelForSequenceClassification, 10 | AutoTokenizer, 11 | PreTrainedTokenizer, 12 | get_scheduler, 13 | ) 14 | 15 | 16 | class GLUETransformer(L.LightningModule): 17 | def __init__( 18 | self, 19 | task_name: str, 20 | model_name_or_path: str, 21 | num_labels: int, 22 | max_length: int | None = None, 23 | weight_decay: float = 0.0, 24 | learning_rate: float = 2e-5, 25 | scheduler_type: str = "linear", 26 | warmup_steps: int = 0, 27 | ): 28 | super().__init__() 29 | self.save_hyperparameters() 30 | 31 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 32 | self.convert_to_features = partial( 33 | self._convert_to_features, tokenizer=tokenizer, max_length=max_length 34 | ) 35 | # pass num_labels through so the classification head matches the task 36 | self.model = AutoModelForSequenceClassification.from_pretrained( 37 | model_name_or_path, num_labels=num_labels 38 | ) 39 | self.metric = evaluate.load("glue", task_name) 40 | 41 | self.training_step_outputs = [] 42 | self.validation_step_outputs = [] 43 | self.test_step_outputs = [] 44 | 45 | def forward(self, batch): 46 | # call the module itself (not .forward) so torch hooks are respected 47 | return self.model(**batch) 48 | 
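    # A batch here is exactly what GLUEDataModule produced via convert_to_features:
    # a dict of tensors from the tokenizer ("input_ids", "attention_mask", and,
    # for BERT-style models, "token_type_ids") plus "labels", so the HF model can
    # be called by keyword expansion and returns an output exposing .loss/.logits.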
49 | def shared_step(self, batch) -> STEP_OUTPUT | None: 50 | output = self.forward(batch) 51 | loss, logits = output.loss, output.logits 52 | labels = batch["labels"] 53 | 54 | if self.hparams.num_labels > 1: 55 | # classification: highest-scoring class 56 | preds = torch.argmax(logits, dim=1) 57 | else: 58 | # regression (num_labels == 1, e.g. stsb): raw scalar predictions 59 | preds = logits.squeeze() 60 | 61 | return {"loss": loss, "preds": preds, "labels": labels} 62 | 63 | def training_step( 64 | self, batch, batch_idx: int, dataloader_idx: int = 0 65 | ) -> STEP_OUTPUT: 66 | output = self.shared_step(batch) 67 | # store a detached copy so the autograd graph is not kept alive all epoch 68 | self.training_step_outputs.append( 69 | (dataloader_idx, {k: v.detach() for k, v in output.items()}) 70 | ) 71 | return output 72 | 73 | def validation_step( 74 | self, batch, batch_idx: int, dataloader_idx: int = 0 75 | ) -> STEP_OUTPUT | None: 76 | output = self.shared_step(batch) 77 | self.validation_step_outputs.append((dataloader_idx, output)) 78 | return output 79 | 80 | def test_step( 81 | self, batch, batch_idx: int, dataloader_idx: int = 0 82 | ) -> STEP_OUTPUT | None: 83 | output = self.shared_step(batch) 84 | self.test_step_outputs.append((dataloader_idx, output)) 85 | return output 86 | 87 | def shared_epoch_end(self, outputs, step: str) -> None: 88 | splits = getattr(self.trainer.datamodule, f"{step}_splits", None) 89 | if splits and len(splits) > 1: 90 | # log one metric set per dataloader, e.g. mnli validation_matched/mismatched 91 | for i, split_name in enumerate(splits): 92 | split = split_name.split("_")[-1] 93 | split_outputs = [x for idx, x in outputs if idx == i] 94 | preds = torch.cat([x["preds"] for x in split_outputs]) 95 | labels = torch.cat([x["labels"] for x in split_outputs]) 96 | loss = torch.stack([x["loss"] for x in split_outputs]).mean() 97 | 98 | split_metrics = { 99 | f"{step}/{split}_{k}": v 100 | for k, v in self.metric.compute( 101 | predictions=preds, references=labels 102 | ).items() 103 | } 104 | 105 | self.log(f"{step}/{split}_loss", loss) 106 | self.log_dict(split_metrics, prog_bar=True) 107 | 108 | return 109 | 110 | preds = torch.cat([x["preds"] for _, x in outputs]) 111 | labels = torch.cat([x["labels"] for _, x in outputs]) 112 | loss = torch.stack([x["loss"] for _, x in outputs]).mean() 113 | 114 | metrics = { 115 | f"{step}/{k}": v 116 | for k, v in self.metric.compute( 117 | predictions=preds, references=labels 118 | ).items() 119 | } 120 | 121 | self.log(f"{step}/loss", loss) 122 | self.log_dict(metrics, prog_bar=True) 123 | 124 | def on_train_epoch_end(self) -> None: 125 | self.shared_epoch_end(self.training_step_outputs, "train") 126 | self.training_step_outputs.clear() 127 | 128 | def on_validation_epoch_end(self) -> None: 129 | self.shared_epoch_end(self.validation_step_outputs, "val") 130 | self.validation_step_outputs.clear() 131 | 132 | def on_test_epoch_end(self) -> None: 133 | self.shared_epoch_end(self.test_step_outputs, "test") 134 | self.test_step_outputs.clear() 135 | 136 | def configure_optimizers(self): 137 | no_decay = ["bias", "LayerNorm.weight"] 138 | optimizer_grouped_parameters = [ 139 | { 140 | "params": [ 141 | p 142 | for n, p in self.named_parameters() 143 | if not any(nd in n for nd in no_decay) 144 | ], 145 | "weight_decay": self.hparams.weight_decay, 146 | }, 147 | { 148 | "params": [ 149 | p 150 | for n, p in self.named_parameters() 151 | if any(nd in n for nd in no_decay) 152 | ], 153 | "weight_decay": 0.0, 154 | }, 155 | ] 156 | optimizer = torch.optim.AdamW( 157 | optimizer_grouped_parameters, 158 | lr=self.hparams.learning_rate, 159 | ) 160 | 161 | scheduler = get_scheduler( 162 | self.hparams.scheduler_type, 163 | optimizer, 164 | num_warmup_steps=self.hparams.warmup_steps, 165 | num_training_steps=self.trainer.estimated_stepping_batches, 166 | ) 167 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 168 | 169 | return [optimizer], [scheduler] 170 | 
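    # Note on the lr scheduler dict above: "interval": "step" makes Lightning
    # advance the scheduler after every optimizer step, so with the default
    # "linear" scheduler_type the learning rate warms up for `warmup_steps`
    # steps and then decays linearly to zero over
    # `trainer.estimated_stepping_batches` total steps.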
171 | @staticmethod 172 | def _convert_to_features( 173 | batch: dict[str, list] | list[Any], 174 | tokenizer: PreTrainedTokenizer, 175 | max_length: int | None = None, 176 | ) -> dict | Any: 177 | features = tokenizer( 178 | batch["text"], 179 | padding="max_length", 180 | truncation=True, 181 | max_length=max_length, 182 | ) 183 | features["labels"] = batch["labels"] 184 | 185 | return features 186 | -------------------------------------------------------------------------------- /src/models/mnist_model.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | import torch.nn.functional as F 4 | from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint 5 | from lightning.pytorch.utilities.types import STEP_OUTPUT 6 | from torchmetrics import Accuracy, MetricCollection 7 | 8 | 9 | class MNISTModel(L.LightningModule): 10 | def __init__( 11 | self, 12 | input_size: int = 28 * 28, 13 | hidden_dim: int = 128, 14 | output_size: int = 10, 15 | learning_rate: float = 1e-3, 16 | ): 17 | super().__init__() 18 | self.save_hyperparameters() 19 | 20 | self.l1 = torch.nn.Linear(input_size, hidden_dim) 21 | self.l2 = torch.nn.Linear(hidden_dim, output_size) 22 | 23 | metrics = MetricCollection( 24 | {"acc": Accuracy(task="multiclass", num_classes=output_size)} 25 | ) 26 | self.train_metrics = metrics.clone(prefix="train/") 27 | self.val_metrics = metrics.clone(prefix="val/") 28 | self.test_metrics = metrics.clone(prefix="test/") 29 | 30 | def forward(self, x): 31 | x = x.view(x.size(0), -1) 32 | x = torch.relu(self.l1(x)) 33 | # no activation on the output layer: F.cross_entropy expects raw logits 34 | return self.l2(x) 35 | 36 | def shared_step(self, batch, step: str) -> STEP_OUTPUT | None: 37 | x, y = batch 38 | logits = self.forward(x) 39 | loss = F.cross_entropy(logits, y) 40 | 41 | preds = torch.argmax(logits, dim=1) 42 | metrics = getattr(self, f"{step}_metrics") 43 | metrics(preds, y) 44 | 45 | self.log(f"{step}/loss", loss) 46 | self.log_dict(metrics, prog_bar=True) 47 | 48 | return loss 49 | 50 | def training_step(self, batch, batch_idx: int) -> STEP_OUTPUT: 51 | return self.shared_step(batch, "train") 52 | 53 | def validation_step(self, batch, batch_idx: int) -> STEP_OUTPUT | None: 54 | return self.shared_step(batch, "val") 55 | 56 | def test_step(self, batch, batch_idx: int) -> STEP_OUTPUT | None: 57 | return self.shared_step(batch, "test") 58 | 59 | def configure_optimizers(self): 60 | return torch.optim.Adam(params=self.parameters(), lr=self.hparams.learning_rate) 61 | 62 | def configure_callbacks(self): 63 | callbacks_kwargs = {"monitor": "val/acc", "mode": "max"} 64 | early_stopping = EarlyStopping(patience=5, **callbacks_kwargs) 65 | model_checkpoint = ModelCheckpoint(**callbacks_kwargs) 66 | return [early_stopping, model_checkpoint] 67 | -------------------------------------------------------------------------------- /src/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/models/modules/__init__.py -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/utils/__init__.py -------------------------------------------------------------------------------- 
/src/utils/lit_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Iterable 3 | 4 | from lightning.pytorch.cli import LightningArgumentParser, LightningCLI 5 | 6 | 7 | class LitCLI(LightningCLI): 8 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 9 | for arg in ["num_labels", "task_name"]: 10 | parser.link_arguments( 11 | f"data.init_args.{arg}", 12 | f"model.init_args.{arg}", 13 | apply_on="instantiate", 14 | ) 15 | 16 | def before_instantiate_classes(self) -> None: 17 | config = self.config[self.subcommand] 18 | 19 | default_root_dir = config.trainer.default_root_dir 20 | logger = config.trainer.logger 21 | if logger and logger is not True: 22 | loggers = logger if isinstance(logger, Iterable) else [logger] 23 | for logger in loggers: 24 | logger.init_args.save_dir = os.path.join( 25 | default_root_dir, self.subcommand 26 | ) 27 | 28 | 29 | def lit_cli(): 30 | LitCLI( 31 | parser_kwargs={ 32 | cmd: { 33 | "default_config_files": ["configs/presets/default.yaml"], 34 | } 35 | for cmd in ["fit", "validate", "test"] 36 | }, 37 | save_config_kwargs={"overwrite": True}, 38 | ) 39 | 40 | 41 | if __name__ == "__main__": 42 | lit_cli() 43 | -------------------------------------------------------------------------------- /src/utils/loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import lightning as L 4 | from lightning.fabric.utilities.types import _PATH 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | 7 | 8 | # TODO: 9 | # https://github.com/Lightning-AI/lightning/issues/14188 10 | # https://github.com/Lightning-AI/lightning/pull/14640 11 | @property 12 | def log_dir(self) -> str: 13 | if self.loggers and self.loggers[0].log_dir is not None: 14 | dirpath = self.loggers[0].log_dir 15 | else: 16 | dirpath = self.default_root_dir 17 | 18 | dirpath = self.strategy.broadcast(dirpath) 19 | return dirpath 20 | 21 | 22 | L.Trainer.log_dir = log_dir 23 | 24 | 25 | def __resolve_ckpt_dir(self, trainer: L.Trainer) -> _PATH: 26 | """Determines model checkpoint save directory at runtime. References attributes from the trainer's logger 27 | to determine where to save checkpoints. The base path for saving weights is set in this priority: 28 | 1. Checkpoint callback's path (if passed in) 29 | 2. The default_root_dir from trainer if trainer has no logger 30 | 3. 
The log_dir from trainer, if trainer has logger 31 | """ 32 | if self.dirpath is not None: 33 | # short circuit if dirpath was passed to ModelCheckpoint 34 | return self.dirpath 35 | if trainer.loggers: 36 | ckpt_path = os.path.join(trainer.log_dir, "checkpoints") 37 | else: 38 | ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") 39 | return ckpt_path 40 | 41 | 42 | ModelCheckpoint._ModelCheckpoint__resolve_ckpt_dir = __resolve_ckpt_dir 43 | -------------------------------------------------------------------------------- /src/utils/sweep_cli.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import math 4 | import os 5 | import shlex 6 | import subprocess 7 | from pathlib import Path 8 | from typing import Any, Literal 9 | 10 | import ray 11 | from jsonargparse import CLI 12 | from ray import train, tune 13 | 14 | ray.init(_temp_dir=str(Path.home() / ".cache" / "ray")) 15 | 16 | 17 | def run_cli(config, debug: bool = True, command: str = "fit", devices: int = 1): 18 | os.chdir(os.environ["TUNE_ORIG_WORKING_DIR"]) 19 | 20 | argv = ["./run", command] 21 | ckpt_path = config.pop("ckpt_path", None) 22 | if ckpt_path is not None: 23 | config_path = Path(ckpt_path).parents[1] / "config.yaml" 24 | argv.extend(["--config", str(config_path)]) 25 | argv.extend(["--ckpt_path", ckpt_path]) 26 | config.pop("config", None) 27 | config.pop("data_config", None) 28 | else: 29 | for cfg in ["config", "data_config"]: 30 | if cfg in config: 31 | argv.extend(["--config", config.pop(cfg)]) 32 | 33 | argv.extend( 34 | itertools.chain( 35 | *[ 36 | [f"--{k}", v if isinstance(v, str) else json.dumps(v)] 37 | for k, v in config.items() 38 | ] 39 | ) 40 | ) 41 | 42 | argv.extend(["--trainer.devices", str(devices)]) 43 | if debug: 44 | argv.extend(["--config", "configs/presets/tester.yaml"]) 45 | 46 | print(shlex.join(argv)) 47 | subprocess.check_output(argv) 48 | 49 | 50 | def sweep( 51 | command: Literal["fit", "validate", "test"], 52 | debug: bool = False, 53 | gpus_per_trial: int | float = 1, 54 | *, 55 | ckpt_paths: list[str | None] | None = None, 56 | configs: list[str] | None = None, 57 | data_configs: list[str | None] | None = None, 58 | override_kwargs: dict[str, Any] | None = None, 59 | ): 60 | param_space = { 61 | **({"ckpt_path": tune.grid_search(ckpt_paths)} if ckpt_paths else {}), 62 | **({"config": tune.grid_search(configs)} if configs else {}), 63 | **({"data_config": tune.grid_search(data_configs)} if data_configs else {}), 64 | **( 65 | { 66 | k: tune.grid_search(v) if isinstance(v, list) else tune.grid_search([v]) 67 | for k, v in override_kwargs.items() 68 | } 69 | if override_kwargs 70 | else {} 71 | ), 72 | } 73 | 74 | tune_config = tune.TuneConfig() 75 | run_config = train.RunConfig( 76 | log_to_file=True, 77 | storage_path=Path("./results/ray").resolve(), 78 | ) 79 | trainable = tune.with_parameters( 80 | run_cli, 81 | debug=debug, 82 | command=command, 83 | devices=math.ceil(gpus_per_trial), 84 | ) 85 | tuner = tune.Tuner( 86 | tune.with_resources(trainable, resources={"gpu": gpus_per_trial}), 87 | param_space=param_space, 88 | tune_config=tune_config, 89 | run_config=run_config, 90 | ) 91 | tuner.fit() 92 | 93 | 94 | def fit(*args, **kwargs): 95 | sweep("fit", *args, **kwargs) 96 | 97 | 98 | def validate(*args, **kwargs): 99 | sweep("validate", *args, **kwargs) 100 | 101 | 102 | def test(*args, **kwargs): 103 | sweep("test", *args, **kwargs) 104 | 105 | 106 | def sweep_cli(): 107 | CLI([fit, validate, 
test]) 108 | 109 | 110 | if __name__ == "__main__": 111 | sweep_cli() 112 | -------------------------------------------------------------------------------- /src/vendor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/vendor/__init__.py --------------------------------------------------------------------------------
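For reference, `sweep_cli.py` turns every list under `ckpt_paths`/`configs`/`data_configs`/`override_kwargs` into a Ray Tune grid search, and each trial shells out to the `run` script via `run_cli`. Under the shipped [sweep_mnist.yaml](configs/sweep_mnist.yaml), the grid is two seeds × one epoch setting × one batch size, so the two trials would roughly invoke the following (flag order as built in `run_cli`, values JSON-encoded):

```console
./run fit --config configs/mnist.yaml --seed_everything 123 --trainer.max_epochs 5 --data.batch_size 64 --trainer.devices 1
./run fit --config configs/mnist.yaml --seed_everything 42 --trainer.max_epochs 5 --data.batch_size 64 --trainer.devices 1
```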