├── .envrc
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs
│   ├── mnist.yaml
│   ├── mrpc.yaml
│   ├── presets
│   │   ├── ddp.yaml
│   │   ├── debugger.yaml
│   │   ├── default.yaml
│   │   ├── limiter.yaml
│   │   ├── overfitter.yaml
│   │   ├── profiler.yaml
│   │   └── tester.yaml
│   └── sweep_mnist.yaml
├── data
│   └── .gitkeep
├── environment.yaml
├── models
│   └── .gitkeep
├── notebooks
│   └── .gitkeep
├── pyproject.toml
├── requirements.txt
├── results
│   └── .gitkeep
├── run
├── scripts
│   ├── .gitkeep
│   ├── print_results
│   ├── run.sh
│   ├── sweep
│   └── sweep_mnist.sh
└── src
    ├── __init__.py
    ├── callbacks
    │   ├── __init__.py
    │   └── metric.py
    ├── datamodules
    │   ├── __init__.py
    │   ├── datasets
    │   │   └── __init__.py
    │   ├── glue_datamodule.py
    │   └── mnist_datamodule.py
    ├── models
    │   ├── __init__.py
    │   ├── glue_transformer.py
    │   ├── mnist_model.py
    │   └── modules
    │       └── __init__.py
    ├── utils
    │   ├── __init__.py
    │   ├── lit_cli.py
    │   ├── loggers.py
    │   └── sweep_cli.py
    └── vendor
        └── __init__.py
/.envrc:
--------------------------------------------------------------------------------
1 | layout conda lit-template
2 | export PATH=$PWD:$PWD/scripts:$PATH
3 | export PYTHONPATH=$PWD${PYTHONPATH:+":$PYTHONPATH"}
4 | [ -f .env ] && dotenv
5 |
--------------------------------------------------------------------------------
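The `.envrc` above is consumed by direnv; note that `layout conda` is not a direnv built-in, so it assumes a `layout_conda` helper is defined in your `~/.direnvrc`. A minimal usage sketch, assuming direnv is installed and hooked into your shell:

```console
cd lightning-template
direnv allow   # approve .envrc once; direnv then exports PATH/PYTHONPATH on every cd into the project
```
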
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python ###
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | ### Lightning-Template ###
133 | /data/*
134 | !/data/.gitkeep
135 | /models/*
136 | !/models/.gitkeep
137 | /results/*
138 | !/results/.gitkeep
139 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: check-executables-have-shebangs
8 | - id: check-json
9 | - id: check-shebang-scripts-are-executable
10 | - id: check-yaml
11 | - id: detect-private-key
12 | - id: end-of-file-fixer
13 | - id: trailing-whitespace
14 |
15 | - repo: https://github.com/charliermarsh/ruff-pre-commit
16 | rev: v0.11.8
17 | hooks:
18 | - id: ruff
19 | - id: ruff-format
20 |
21 | - repo: https://github.com/kynan/nbstripout.git
22 | rev: 0.8.1
23 | hooks:
24 | - id: nbstripout
25 |
--------------------------------------------------------------------------------
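Once installed, the hooks above run automatically on `git commit`; they can also be run against the whole tree, which is handy right after cloning. A usage sketch:

```console
pre-commit install            # register the git hook
pre-commit run --all-files    # run ruff, ruff-format, nbstripout, etc. on every file once
```
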
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Tianshu Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lightning-Template
2 |
3 | [](https://github.com/tshu-w/lightning-template)
4 | [](https://pytorch.org)
5 | [](https://lightning.ai)
6 | [](https://github.com/astral-sh/ruff)
7 | [](https://github.com/tshu-w/lightning-template?tab=MIT-1-ov-file)
8 |
9 | A clean and flexible [Pytorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) template to kickstart and structure your deep learning project, ensuring efficient workflow, reproducibility, and easy extensibility for rapid experiments.
10 |
11 | ### Why Lightning-Template?
12 |
13 | Pytorch Lightning is a deep learning framework designed for professional AI researchers and engineers, freeing users from boilerplate code (_e.g._, multi-GPU/TPU/HPU training, early stopping, and checkpointing) so they can focus on going from idea to paper/production.
14 |
15 | This Lightning template leverages [Lightning CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) to separate configuration from source code, guaranteeing reproducibility of experiments, and incorporates many other [best practices](#best-practices).
16 |
17 | + **Compared to [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template)**: Our template provides similar functionality through a simple and straightforward encapsulation of Lightning's built-in CLI, making it suitable for users who prefer minimal setup without an additional Hydra layer.
18 |
19 | > Note: This is an unofficial project that lacks comprehensive tests and continuous integration.
20 |
21 | ### Quickstart
22 |
23 | ```console
24 | git clone https://github.com/YourGithubName/your-repository-name
25 | cd your-repository-name
26 |
27 | # [SUGGESTED] use conda environment
28 | conda env create -n env-name -f environment.yaml
29 | conda activate env-name
30 |
31 | # [ALTERNATIVE] install requirements directly
32 | pip install -r requirements.txt
33 |
34 | # Run the sample script, i.e., ./run fit --config configs/mnist.yaml
35 | bash -x scripts/run.sh
36 | ```
37 |
38 | ### Workflow - how it works
39 |
40 | Before using this template, please read the basic Pytorch Lightning documentation: [Lightning in 15 minutes](https://lightning.ai/docs/pytorch/stable/starter/introduction.html).
41 |
42 | 1. Define a [Lightning Module](https://lightning.ai/docs/pytorch/2.4.0/common/lightning_module.html) (Examples: [mnist_model.py](src/models/mnist_model.py) and [glue_transformer.py](src/models/glue_transformer.py))
43 | 2. Define a [Lightning DataModule](https://lightning.ai/docs/pytorch/2.4.0/data/datamodule.html#lightningdatamodule) (Examples: [mnist_datamodule.py](src/datamodules/mnist_datamodule.py) and [glue_datamodule.py](src/datamodules/glue_datamodule.py))
44 | 3. Prepare your experiment configs (Examples: [mnist.yaml](configs/mnist.yaml) and [mrpc.yaml](configs/mrpc.yaml))
45 | 4. Run experiments (_cf._, [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html))
46 | - To see the available commands type:
47 | ```console
48 | ./run --help
49 | ```
50 | - Train a model from the config:
51 | ```console
52 | ./run fit --config configs/mnist.yaml
53 | ```
54 | - Override config options:
55 | ```console
56 | ./run fit --config configs/mnist.yaml --trainer.precision 16 --model.learning_rate 0.1 --data.batch_size 64
57 | ```
58 | - Separate model and datamodule configs:
59 | ```console
60 | ./run fit --config configs/data.yaml --config configs/model.yaml
61 | ```
62 |
63 | ### Project Structure
64 | The directory structure of a project looks like this:
65 | ```
66 | lightning-template
67 | ├── configs           ← Directory of Configs
68 | │   ├── mnist.yaml
69 | │   ├── mrpc.yaml
70 | │   ├── presets       ← Preset configs for Lightning features
71 | │   └── sweep_mnist.yaml
72 | ├── data              ← Directory of Data
73 | ├── environment.yaml
74 | ├── models            ← Directory of Models
75 | ├── notebooks         ← Directory of Notebooks
76 | ├── pyproject.toml
77 | ├── README.md
78 | ├── requirements.txt
79 | ├── results           ← Directory of Results
80 | ├── run               ← Script to Run Lightning CLI
81 | ├── scripts           ← Directory of Scripts
82 | │   ├── print_results
83 | │   ├── run.sh
84 | │   ├── sweep         ← Script to sweep Experiments
85 | │   └── sweep_mnist.sh
86 | └── src               ← Directory of Source Code
87 |     ├── callbacks
88 |     ├── datamodules
89 |     ├── models
90 |     ├── utils
91 |     └── vendor        ← Directory of Third-Party Code
92 | ```
93 |
94 | ### Best Practices
95 | 1. Use [conda](https://docs.anaconda.com/miniconda/) to manage environments.
96 | 2. Leverage Lightning's awesome features (_cf._, [How-to Guides](https://lightning.ai/docs/pytorch/stable/common/) & [Glossary](https://lightning.ai/docs/pytorch/stable/glossary/)).
97 | 3. Use [pre-commit](https://pre-commit.com) and [ruff](https://docs.astral.sh/ruff) to check and format code with configuration in [pyproject.toml](pyproject.toml) and [.pre-commit-config.yaml](.pre-commit-config.yaml).
98 | ```console
99 | pre-commit install
100 | ```
101 | 4. Use [direnv](https://direnv.net) to automatically switch environments and set variables when entering the project directory (_cf._, [.envrc](.envrc)).
102 | ```console
103 | λ cd lightning-template
104 | direnv: loading ~/lightning-template/.envrc
105 | direnv: export +CONDA_DEFAULT_ENV +CONDA_EXE +CONDA_PREFIX +CONDA_PROMPT_MODIFIER +CONDA_PYTHON_EXE +CONDA_SHLVL +_CE_CONDA +_CE_M ~PATH ~PYTHONPATH
106 | ```
107 |    1. Add the project root to `PATH` to use the `run` script directly.
108 |       ```console
109 |       export PATH=$PWD:$PWD/scripts:$PATH
110 |       run fit --config configs/mnist.yaml
111 |       ```
112 |    2. Add the project root to `PYTHONPATH` to avoid modifying `sys.path` in scripts.
113 |       ```console
114 |       export PYTHONPATH=$PWD${PYTHONPATH:+":$PYTHONPATH"}
115 |       ```
116 |    3. Save private variables in `.env`.
117 | 5. Use [shtab](https://jsonargparse.readthedocs.io/en/stable/#tab-completion) to generate a shell completion file.
118 |
119 | 6. Use [ray tune](https://docs.ray.io/en/latest/tune/index.html) to sweep parameters or run hyperparameter searches (_cf._, [sweep_cli.py](src/utils/sweep_cli.py)).
120 | ```console
121 | bash ./scripts/sweep --config configs/sweep_mnist.yaml
122 | ```
123 | 7. Use third-party loggers (_e.g._, [w&b](https://wandb.ai) and [aim](https://aimstack.io)) to track experiments.
124 |
125 | ### DELETE EVERYTHING ABOVE FOR YOUR PROJECT
126 |
127 | ---
128 |
129 |
130 |
131 | # Your Project Name
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 | ## Description
141 | What it does
142 |
143 | ## How to run
144 | First, install dependencies
145 | ```console
146 | # clone project
147 | git clone https://github.com/YourGithubName/your-repository-name
148 | cd your-repository-name
149 |
150 | # [SUGGESTED] use conda environment
151 | conda env create -f environment.yaml
152 | conda activate lit-template
153 |
154 | # [ALTERNATIVE] install requirements directly
155 | pip install -r requirements.txt
156 | ```
157 |
158 | Next, to obtain the main results of the paper:
159 | ```console
160 | # commands to get the main results
161 | ```
162 |
163 | You can also run experiments with the `run` script.
164 | ```console
165 | # fit with the demo config
166 | ./run fit --config configs/demo.yaml
167 | # or specific command line arguments
168 | ./run fit --model MNISTModel --data MNISTDataModule --data.batch_size 32 --trainer.accelerator cpu
169 |
170 | # evaluate with the checkpoint
171 | ./run test --config configs/demo.yaml --ckpt_path ckpt_path
172 |
173 | # get the script help
174 | ./run --help
175 | ./run fit --help
176 | ```
177 |
178 | ## Citation
179 | ```
180 | @article{YourName,
181 | title={Your Title},
182 | author={Your team},
183 | journal={Location},
184 | year={Year}
185 | }
186 | ```
187 |
--------------------------------------------------------------------------------
/configs/mnist.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 123
2 | trainer:
3 | max_epochs: 20
4 | model:
5 | class_path: MNISTModel
6 | data:
7 | class_path: MNISTDataModule
8 |
--------------------------------------------------------------------------------
/configs/mrpc.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 123
2 | trainer:
3 | max_epochs: 30
4 | model:
5 | class_path: GLUETransformer
6 | init_args:
7 | model_name_or_path: bert-base-uncased
8 | max_length: 256
9 | data:
10 | class_path: GLUEDataModule
11 | init_args:
12 | task_name: mrpc
13 | batch_size: 32
14 | num_workers: 0
15 | pin_memory: true
16 |
--------------------------------------------------------------------------------
/configs/presets/ddp.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | devices: 4
3 | # strategy: ddp
4 | strategy: ddp_find_unused_parameters_false
5 |
--------------------------------------------------------------------------------
/configs/presets/debugger.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | accelerator: "cpu"
3 | detect_anomaly: true
4 | fast_dev_run: 1
5 |
--------------------------------------------------------------------------------
/configs/presets/default.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 42
2 | trainer:
3 | default_root_dir: results
4 | callbacks:
5 | class_path: Metric
6 | logger:
7 | class_path: CSVLogger
8 | init_args:
9 | save_dir: results
10 | accelerator: auto
11 | devices: 1
12 |
--------------------------------------------------------------------------------
/configs/presets/limiter.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | max_epochs: 5
3 | limit_train_batches: 10
4 | limit_val_batches: 5
5 | limit_test_batches: 5
6 | logger: false
7 |
--------------------------------------------------------------------------------
/configs/presets/overfitter.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | overfit_batches: 0.01
3 | logger: false
4 |
--------------------------------------------------------------------------------
/configs/presets/profiler.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | max_epochs: 1
3 | profiler: "simple"
4 | # profiler: "advanced"
5 | # profiler: "pytorch"
6 |
--------------------------------------------------------------------------------
/configs/presets/tester.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | fast_dev_run: 5
3 | logger: false
4 |
--------------------------------------------------------------------------------
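The presets above are plain Lightning CLI config fragments, so they can be stacked after an experiment config; with multiple `--config` files, later ones override earlier ones. A sketch of two common combinations:

```console
# debug the MNIST experiment on CPU with anomaly detection and a fast dev run
./run fit --config configs/mnist.yaml --config configs/presets/debugger.yaml

# profile one epoch of the MRPC experiment
./run fit --config configs/mrpc.yaml --config configs/presets/profiler.yaml
```
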
/configs/sweep_mnist.yaml:
--------------------------------------------------------------------------------
1 | fit:
2 | debug: false
3 | gpus_per_trial: 1
4 | configs:
5 | - configs/mnist.yaml
6 | override_kwargs:
7 | seed_everything:
8 | - 123
9 | - 42
10 | trainer.max_epochs: 5
11 | data.batch_size:
12 | - 64
13 |
--------------------------------------------------------------------------------
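`sweep_cli.py` (further below) turns every list under `override_kwargs` into a `tune.grid_search`, so this file produces two trials (one per seed), each running roughly the following command; the exact argument order comes from `run_cli`:

```console
./run fit --config configs/mnist.yaml --seed_everything 123 --trainer.max_epochs 5 --data.batch_size 64 --trainer.devices 1
```
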
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/data/.gitkeep
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: lit-template
2 | channels:
3 | - defaults
4 | - conda-forge
5 | - pytorch
6 | dependencies:
7 | - python=3.12
8 | - pip
9 | - pip:
10 | - --requirement requirements.txt
11 |
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/models/.gitkeep
--------------------------------------------------------------------------------
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/notebooks/.gitkeep
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # https://github.com/microsoft/pyright
2 | [tool.pyright]
3 | include = ["src"]
4 | venv = "lit-template"
5 | typeCheckingMode = "off"
6 | useLibraryCodeForTypes = true
7 |
8 | # https://github.com/charliermarsh/ruff
9 | [tool.ruff]
10 | fix = true
11 | line-length = 88
12 | target-version = "py310"
13 | [tool.ruff.lint]
14 | select = [
15 | "E", # pycodestyle
16 | "F", # Pyflakes
17 | "UP", # pyupgrade
18 | "B", # flake8-bugbear
19 | "SIM", # flake8-simplify
20 | "I", # isort
21 | ]
22 | ignore = ["E501"]
23 | # https://github.com/timothycrosley/isort/
24 | [tool.ruff.lint.isort]
25 | combine-as-imports = true
26 |
--------------------------------------------------------------------------------
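The ruff settings above are also picked up when ruff is invoked directly (with `fix = true`, `check` applies autofixes). A usage sketch:

```console
ruff check .     # lint with the selected rule sets (E, F, UP, B, SIM, I)
ruff format .    # apply formatting
```
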
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch >= 2.4.0
2 | lightning >= 2.4.0
3 | torchvision
4 | jsonargparse[signatures] # for CLI
5 | ray[tune]
6 |
7 | transformers
8 | scikit-learn
9 | datasets
10 | evaluate
11 |
12 | # dev tools
13 | jupyterlab
14 | shtab
15 |
--------------------------------------------------------------------------------
/results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/results/.gitkeep
--------------------------------------------------------------------------------
/run:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd "$(dirname "$0")"
3 |
4 | python -m src.utils.lit_cli "$@"
5 |
--------------------------------------------------------------------------------
/scripts/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/scripts/.gitkeep
--------------------------------------------------------------------------------
/scripts/print_results:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Modified from: https://github.com/allenai/allennlp/blob/main/allennlp/commands/print_results.py
3 |
4 | import argparse
5 | import json
6 | from pathlib import Path
7 | from signal import SIG_DFL, SIGPIPE, signal
8 |
9 | signal(SIGPIPE, SIG_DFL)
10 |
11 |
12 | def main(args: argparse.Namespace):
13 | """
14 | Prints results from an `argparse.Namespace` object.
15 | """
16 | path = args.path
17 | metrics_name = args.metrics_filename
18 | keys = args.keys
19 |
20 | results_dict = {}
21 | for f in path.rglob(metrics_name):
22 | with open(f) as file_:
23 | metrics = json.load(file_)
24 | name = f.parents[0].relative_to(f.parents[2])
25 | results_dict[name] = metrics
26 |
27 | sorted_keys = sorted(list(results_dict.keys()))
28 | print(f"{path.name}, {', '.join(keys)}")
29 | for name in sorted_keys:
30 | results = results_dict[name]
31 | keys_to_print = (str(results.get(key, "N/A")) for key in keys)
32 | print(f"{name}, {', '.join(keys_to_print)}")
33 |
34 |
35 | if __name__ == "__main__":
36 | parser = argparse.ArgumentParser(
37 | description="Print experiment results in a helpful CSV format."
38 | )
39 |
40 | parser.add_argument(
41 | "path",
42 | type=Path,
43 | help="Path to recursively search for experiment directories.",
44 | )
45 | parser.add_argument(
46 | "-k",
47 | "--keys",
48 | type=str,
49 | nargs="+",
50 | help='Keys to print from metrics.json. Keys not present in all metrics.json will result in "N/A"',
51 | default=[],
52 | required=False,
53 | )
54 | parser.add_argument(
55 | "-m",
56 | "--metrics-filename",
57 | type=str,
58 | help="Name of the metrics file to inspect.",
59 | default="metrics.json",
60 | required=False,
61 | )
62 |
63 | args = parser.parse_args()
64 | main(args)
65 |
--------------------------------------------------------------------------------
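The script walks a directory tree for `metrics.json` files (written by the `Metric` callback) and prints one CSV row per run, with missing keys shown as `N/A`. A usage sketch, assuming the default `results` layout and the metric names logged by the MNIST example:

```console
scripts/print_results results/fit --keys val/acc test/acc
```
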
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | ./run fit --config configs/mnist.yaml
2 |
--------------------------------------------------------------------------------
/scripts/sweep:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd "$(dirname "$(dirname "$0")")"
3 |
4 | python -m src.utils.sweep_cli "$@"
5 |
--------------------------------------------------------------------------------
/scripts/sweep_mnist.sh:
--------------------------------------------------------------------------------
1 | bash ./scripts/sweep --config configs/sweep_mnist.yaml
2 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from . import callbacks, datamodules, models
2 | from .utils import loggers
3 |
4 | __all__ = ["callbacks", "datamodules", "models", "loggers"]
5 |
--------------------------------------------------------------------------------
/src/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .metric import Metric
2 |
3 | __all__ = ["Metric"]
4 |
--------------------------------------------------------------------------------
/src/callbacks/metric.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from pathlib import Path
4 |
5 | import lightning as L
6 | from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
7 | from lightning.pytorch.trainer.states import TrainerFn
8 |
9 |
10 | class Metric(L.Callback):
11 | r"""
12 | Save logged metrics to ``Trainer.log_dir``.
13 | """
14 |
15 | def teardown(
16 | self,
17 | trainer: L.Trainer,
18 | pl_module: L.LightningModule,
19 | stage: str | None = None,
20 | ) -> None:
21 | metrics = {}
22 | if stage == TrainerFn.FITTING:
23 | if (
24 | trainer.checkpoint_callback
25 | and trainer.checkpoint_callback.best_model_path
26 | ):
27 | ckpt_path = trainer.checkpoint_callback.best_model_path
28 | # inhibit disturbing logging
29 | logging.getLogger("lightning.pytorch.utilities.distributed").setLevel(
30 | logging.WARNING
31 | )
32 | logging.getLogger("lightning.pytorch.accelerators.gpu").setLevel(
33 | logging.WARNING
34 | )
35 |
36 | fn_kwargs = {
37 | "model": pl_module,
38 | "datamodule": trainer.datamodule,
39 | "ckpt_path": ckpt_path,
40 | }
41 |
42 | val_metrics = {}
43 | if trainer.validate_loop._data_source.is_defined():
44 | trainer.validate(**fn_kwargs)
45 | val_metrics = convert_tensors_to_scalars(trainer.logged_metrics)
46 |
47 | test_metrics = {}
48 | if trainer.test_loop._data_source.is_defined():
49 | trainer.test(**fn_kwargs)
50 | test_metrics = convert_tensors_to_scalars(trainer.logged_metrics)
51 |
52 | metrics = {**val_metrics, **test_metrics}
53 | else:
54 | metrics = convert_tensors_to_scalars(trainer.logged_metrics)
55 |
56 | if metrics:
57 | metrics_str = json.dumps(metrics, ensure_ascii=False, indent=2)
58 |
59 | metrics_file = Path(trainer.log_dir) / "metrics.json"
60 | with metrics_file.open("w") as f:
61 | f.write(metrics_str)
62 |
--------------------------------------------------------------------------------
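With the default preset (CSVLogger under `results` and the `fit` subcommand), the callback writes `metrics.json` into the logger's run directory; the version folder varies per run. A sketch of where to look:

```console
cat results/fit/lightning_logs/version_0/metrics.json
```
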
/src/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | from .glue_datamodule import GLUEDataModule
2 | from .mnist_datamodule import MNISTDataModule
3 |
4 | __all__ = ["GLUEDataModule", "MNISTDataModule"]
5 |
--------------------------------------------------------------------------------
/src/datamodules/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/datamodules/datasets/__init__.py
--------------------------------------------------------------------------------
/src/datamodules/glue_datamodule.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from functools import partial
3 | from typing import Literal
4 |
5 | import lightning as L
6 | from datasets import load_dataset
7 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
8 | from torch.utils.data import DataLoader
9 |
10 | warnings.filterwarnings(
11 | "ignore", ".*Consider increasing the value of the `num_workers` argument*"
12 | )
13 |
14 | TASK_NAME = Literal[
15 | "cola",
16 | "sst2",
17 | "mrpc",
18 | "qqp",
19 | "stsb",
20 | "mnli",
21 | "qnli",
22 | "rte",
23 | "wnli",
24 | "ax",
25 | ]
26 |
27 |
28 | class GLUEDataModule(L.LightningDataModule):
29 | task_text_field_map = {
30 | "cola": ["sentence"],
31 | "sst2": ["sentence"],
32 | "mrpc": ["sentence1", "sentence2"],
33 | "qqp": ["question1", "question2"],
34 | "stsb": ["sentence1", "sentence2"],
35 | "mnli": ["premise", "hypothesis"],
36 | "qnli": ["question", "sentence"],
37 | "rte": ["sentence1", "sentence2"],
38 | "wnli": ["sentence1", "sentence2"],
39 | "ax": ["premise", "hypothesis"],
40 | }
41 |
42 | glue_task_num_labels = {
43 | "cola": 2,
44 | "sst2": 2,
45 | "mrpc": 2,
46 | "qqp": 2,
47 | "stsb": 1,
48 | "mnli": 3,
49 | "qnli": 2,
50 | "rte": 2,
51 | "wnli": 2,
52 | "ax": 3,
53 | }
54 |
55 | def __init__(
56 | self,
57 | task_name: TASK_NAME = "mrpc",
58 | batch_size: int = 32,
59 | num_workers: int = 0,
60 | pin_memory: bool = False,
61 | ):
62 | super().__init__()
63 | self.save_hyperparameters()
64 |
65 | self.task_name = task_name
66 | self.num_labels = self.glue_task_num_labels[task_name]
67 | self.text_fields = self.task_text_field_map[task_name]
68 |
69 | def prepare_data(self) -> None:
70 | # setup first to prevent datasets cache conflicts in multiple processes.
71 | self.setup()
72 |
73 | def setup(self, stage: str | None = None) -> None:
74 | if not hasattr(self, "datasets"):
75 | convert_to_features = self.trainer.model.convert_to_features
76 | preprocess_fn = partial(self._preprocess, text_fields=self.text_fields)
77 |
78 | def preprocess(x):
79 | return convert_to_features(preprocess_fn(x))
80 |
81 | datasets = load_dataset("glue", self.task_name)
82 | columns_names = self.text_fields + ["label", "idx"]
83 | self.datasets = datasets.map(
84 | preprocess,
85 | batched=True,
86 | remove_columns=columns_names,
87 | )
88 |
89 | self.datasets.set_format(type="torch")
90 |
91 | self.val_splits = [x for x in self.datasets if "validation" in x]
92 | self.test_splits = [x for x in self.datasets if "test" in x]
93 |
94 | self.collate_fn = getattr(self.trainer.model, "collate_fn", None)
95 |
96 | def train_dataloader(self) -> TRAIN_DATALOADERS:
97 | return DataLoader(
98 | dataset=self.datasets["train"],
99 | batch_size=self.hparams.batch_size,
100 | num_workers=self.hparams.num_workers,
101 | pin_memory=self.hparams.pin_memory,
102 | collate_fn=self.collate_fn,
103 | persistent_workers=self.hparams.num_workers > 0,
104 | shuffle=True,
105 | )
106 |
107 | def val_dataloader(self) -> EVAL_DATALOADERS:
108 | val_dataloaders = [
109 | DataLoader(
110 | dataset=self.datasets[x],
111 | batch_size=self.hparams.batch_size,
112 | num_workers=self.hparams.num_workers,
113 | pin_memory=self.hparams.pin_memory,
114 | collate_fn=self.collate_fn,
115 | persistent_workers=self.hparams.num_workers > 0,
116 | shuffle=False,
117 | )
118 | for x in self.val_splits
119 | ]
120 |
121 | return val_dataloaders[0] if len(val_dataloaders) == 1 else val_dataloaders
122 |
123 | def test_dataloader(self) -> EVAL_DATALOADERS:
124 | test_dataloaders = [
125 | DataLoader(
126 | dataset=self.datasets[x],
127 | batch_size=self.hparams.batch_size,
128 | num_workers=self.hparams.num_workers,
129 | pin_memory=self.hparams.pin_memory,
130 | collate_fn=self.collate_fn,
131 | persistent_workers=self.hparams.num_workers > 0,
132 | shuffle=False,
133 | )
134 | for x in self.test_splits
135 | ]
136 |
137 | return test_dataloaders[0] if len(test_dataloaders) == 1 else test_dataloaders
138 |
139 | @staticmethod
140 | def _preprocess(batch, text_fields):
141 | if len(text_fields) > 1:
142 | text = list(zip(batch[text_fields[0]], batch[text_fields[1]], strict=True))
143 | else:
144 | text = batch[text_fields[0]]
145 | labels = batch["label"]
146 |
147 | return {"text": text, "labels": labels}
148 |
--------------------------------------------------------------------------------
/src/datamodules/mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
3 | from torch.utils.data import DataLoader, random_split
4 | from torchvision.datasets import MNIST
5 | from torchvision.transforms import transforms
6 |
7 |
8 | class MNISTDataModule(L.LightningDataModule):
9 | def __init__(
10 | self,
11 | data_dir: str = "data/",
12 | batch_size: int = 64,
13 | num_workers: int = 0,
14 | pin_memory: bool = False,
15 | ):
16 | super().__init__()
17 | self.save_hyperparameters()
18 |
19 | self.transforms = transforms.ToTensor()
20 | self.data = {}
21 |
22 | def prepare_data(self) -> None:
23 | MNIST(self.hparams.data_dir, train=True, download=True)
24 | MNIST(self.hparams.data_dir, train=False, download=True)
25 |
26 | def setup(self, stage: str | None = None) -> None:
27 | if not self.data:
28 | dataset = MNIST(
29 | self.hparams.data_dir, train=True, transform=self.transforms
30 | )
31 | self.data["train"], self.data["val"] = random_split(dataset, [55000, 5000])
32 |
33 | self.data["test"] = MNIST(
34 | self.hparams.data_dir, train=False, transform=self.transforms
35 | )
36 |
37 | def train_dataloader(self) -> TRAIN_DATALOADERS:
38 | return DataLoader(
39 | dataset=self.data["train"],
40 | batch_size=self.hparams.batch_size,
41 | num_workers=self.hparams.num_workers,
42 | pin_memory=self.hparams.pin_memory,
43 | persistent_workers=self.hparams.num_workers > 0,
44 | shuffle=True,
45 | )
46 |
47 | def val_dataloader(self) -> EVAL_DATALOADERS:
48 | return DataLoader(
49 | dataset=self.data["val"],
50 | batch_size=self.hparams.batch_size,
51 | num_workers=self.hparams.num_workers,
52 | pin_memory=self.hparams.pin_memory,
53 | persistent_workers=self.hparams.num_workers > 0,
54 | shuffle=False,
55 | )
56 |
57 | def test_dataloader(self) -> EVAL_DATALOADERS:
58 | return DataLoader(
59 | dataset=self.data["test"],
60 | batch_size=self.hparams.batch_size,
61 | num_workers=self.hparams.num_workers,
62 | pin_memory=self.hparams.pin_memory,
63 | persistent_workers=self.hparams.num_workers > 0,
64 | shuffle=False,
65 | )
66 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .glue_transformer import GLUETransformer
2 | from .mnist_model import MNISTModel
3 |
4 | __all__ = ["GLUETransformer", "MNISTModel"]
5 |
--------------------------------------------------------------------------------
/src/models/glue_transformer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any
3 |
4 | import evaluate
5 | import lightning as L
6 | import torch
7 | from lightning.pytorch.utilities.types import STEP_OUTPUT
8 | from transformers import (
9 | AutoModelForSequenceClassification,
10 | AutoTokenizer,
11 | PreTrainedTokenizer,
12 | get_scheduler,
13 | )
14 |
15 |
16 | class GLUETransformer(L.LightningModule):
17 | def __init__(
18 | self,
19 | task_name: str,
20 | model_name_or_path: str,
21 | num_labels: int,
22 | max_length: int | None = None,
23 | weight_decay: float = 0.0,
24 | learning_rate: float = 2e-5,
25 | scheduler_type: str = "linear",
26 | warmup_steps: int = 0,
27 | ):
28 | super().__init__()
29 | self.save_hyperparameters()
30 |
31 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
32 | self.convert_to_features = partial(
33 | self._convert_to_features, tokenizer=tokenizer, max_length=max_length
34 | )
35 | self.model = AutoModelForSequenceClassification.from_pretrained(
36 | model_name_or_path
37 | )
38 | self.metric = evaluate.load("glue", task_name)
39 |
40 |         self.training_step_outputs, self.validation_step_outputs = [], []
41 |         self.test_step_outputs = []
42 |
43 | def forward(self, batch):
44 | return self.model.forward(**batch)
45 |
46 | def shared_step(self, batch) -> STEP_OUTPUT | None:
47 | output = self.forward(batch)
48 | loss, logits = output.loss, output.logits
49 | labels = batch["labels"]
50 |
51 |         if self.hparams.num_labels > 1:
52 | preds = torch.argmax(logits, dim=1)
53 | elif self.hparams.num_labels == 1:
54 | preds = logits.squeeze()
55 |
56 | return {"loss": loss, "preds": preds, "labels": labels}
57 |
58 |     def training_step(self, batch, batch_idx: int) -> STEP_OUTPUT:
59 |         output = self.shared_step(batch)
60 |         self.training_step_outputs.append({k: v.detach() for k, v in output.items()})
61 |         return output
62 |
63 | def validation_step(
64 | self, batch, batch_idx: int, dataloader_idx: int | None = None
65 | ) -> STEP_OUTPUT | None:
66 | output = self.shared_step(batch)
67 | self.validation_step_outputs.append(output)
68 | return output
69 |
70 | def test_step(
71 | self, batch, batch_idx: int, dataloader_idx: int | None = None
72 | ) -> STEP_OUTPUT | None:
73 | output = self.shared_step(batch)
74 | self.test_step_outputs.append(output)
75 | return output
76 |
77 | def shared_epoch_end(self, outputs, step: str) -> None:
78 | if hasattr(self.trainer.datamodule, f"{step}_splits"):
79 | splits = getattr(self.trainer.datamodule, f"{step}_splits")
80 | if len(splits) > 1:
81 | for i, output in enumerate(outputs):
82 | split = splits[i].split("_")[-1]
83 | preds = torch.cat([x["preds"] for x in output])
84 | labels = torch.cat([x["labels"] for x in output])
85 | loss = torch.stack([x["loss"] for x in output]).mean()
86 |
87 | split_metrics = {
88 | f"{step}/{split}_{k}": v
89 | for k, v in self.metric.compute(
90 | predictions=preds, references=labels
91 | ).items()
92 | }
93 |
94 | self.log(f"{step}/{split}_loss", loss)
95 | self.log_dict(split_metrics, prog_bar=True)
96 |
97 | return loss
98 |
99 | preds = torch.cat([x["preds"] for x in outputs])
100 | labels = torch.cat([x["labels"] for x in outputs])
101 | loss = torch.stack([x["loss"] for x in outputs]).mean()
102 |
103 | metrics = {
104 | f"{step}/{k}": v
105 | for k, v in self.metric.compute(
106 | predictions=preds, references=labels
107 | ).items()
108 | }
109 |
110 | self.log(f"{step}/loss", loss)
111 | self.log_dict(metrics, prog_bar=True)
112 |
113 |     def on_train_epoch_end(self) -> None:
114 |         self.shared_epoch_end(self.training_step_outputs, "train")
115 |         self.training_step_outputs.clear()
116 | def on_validation_epoch_end(self) -> None:
117 | self.shared_epoch_end(self.validation_step_outputs, "val")
118 | self.validation_step_outputs.clear()
119 |
120 | def on_test_epoch_end(self) -> None:
121 | self.shared_epoch_end(self.test_step_outputs, "test")
122 | self.test_step_outputs.clear()
123 |
124 | def configure_optimizers(self):
125 | no_decay = ["bias", "LayerNorm.weight"]
126 | optimizer_grouped_parameters = [
127 | {
128 | "params": [
129 | p
130 | for n, p in self.named_parameters()
131 | if not any(nd in n for nd in no_decay)
132 | ],
133 | "weight_decay": self.hparams.weight_decay,
134 | },
135 | {
136 | "params": [
137 | p
138 | for n, p in self.named_parameters()
139 | if any(nd in n for nd in no_decay)
140 | ],
141 | "weight_decay": 0.0,
142 | },
143 | ]
144 | optimizer = torch.optim.AdamW(
145 | optimizer_grouped_parameters,
146 | lr=self.hparams.learning_rate,
147 | )
148 |
149 | scheduler = get_scheduler(
150 | self.hparams.scheduler_type,
151 | optimizer,
152 | num_warmup_steps=self.hparams.warmup_steps,
153 | num_training_steps=self.trainer.estimated_stepping_batches,
154 | )
155 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
156 |
157 | return [optimizer], [scheduler]
158 |
159 | @staticmethod
160 | def _convert_to_features(
161 | batch: dict[str, list] | list[Any],
162 | tokenizer: PreTrainedTokenizer,
163 | max_length: int | None = None,
164 | ) -> dict | Any:
165 | features = tokenizer(
166 | batch["text"],
167 | padding="max_length",
168 | truncation=True,
169 | max_length=max_length,
170 | )
171 | features["labels"] = batch["labels"]
172 |
173 | return features
174 |
--------------------------------------------------------------------------------
/src/models/mnist_model.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | import torch
3 | import torch.nn.functional as F
4 | from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
5 | from lightning.pytorch.utilities.types import STEP_OUTPUT
6 | from torchmetrics import Accuracy, MetricCollection
7 |
8 |
9 | class MNISTModel(L.LightningModule):
10 | def __init__(
11 | self,
12 | input_size: int = 28 * 28,
13 | hidden_dim: int = 128,
14 | output_size: int = 10,
15 | learning_rate: float = 1e-3,
16 | ):
17 | super().__init__()
18 | self.save_hyperparameters()
19 |
20 | self.l1 = torch.nn.Linear(input_size, hidden_dim)
21 | self.l2 = torch.nn.Linear(hidden_dim, output_size)
22 |
23 | metrics = MetricCollection({"acc": Accuracy(task="multiclass", num_classes=10)})
24 | self.train_metrics = metrics.clone(prefix="train/")
25 | self.val_metrics = metrics.clone(prefix="val/")
26 | self.test_metrics = metrics.clone(prefix="test/")
27 |
28 | def forward(self, x):
29 | x = x.view(x.size(0), -1)
30 | x = torch.relu(self.l1(x))
31 |         x = self.l2(x)
32 |
33 | return x
34 |
35 | def shared_step(self, batch, step: str) -> STEP_OUTPUT | None:
36 | x, y = batch
37 | logits = self.forward(x)
38 | loss = F.cross_entropy(logits, y)
39 |
40 | preds = torch.argmax(logits, dim=1)
41 | metrics = getattr(self, f"{step}_metrics")
42 | metrics(preds, y)
43 |
44 | self.log(f"{step}/loss", loss)
45 | self.log_dict(metrics, prog_bar=True)
46 |
47 | return loss
48 |
49 | def training_step(self, batch, batch_idx: int) -> STEP_OUTPUT:
50 | return self.shared_step(batch, "train")
51 |
52 | def validation_step(self, batch, batch_idx: int) -> STEP_OUTPUT | None:
53 | return self.shared_step(batch, "val")
54 |
55 | def test_step(self, batch, batch_idx: int) -> STEP_OUTPUT | None:
56 | return self.shared_step(batch, "test")
57 |
58 | def configure_optimizers(self):
59 | return torch.optim.Adam(params=self.parameters(), lr=self.hparams.learning_rate)
60 |
61 | def configure_callbacks(self):
62 |         callbacks_kwargs = {"monitor": "val/acc", "mode": "max"}
63 |         early_stopping = EarlyStopping(patience=5, **callbacks_kwargs)
64 |         model_checkpoint = ModelCheckpoint(**callbacks_kwargs)
65 | return [early_stopping, model_checkpoint]
66 |
--------------------------------------------------------------------------------
/src/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/models/modules/__init__.py
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/lit_cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections.abc import Iterable
3 |
4 | from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
5 |
6 |
7 | class LitCLI(LightningCLI):
8 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
9 | for arg in ["num_labels", "task_name"]:
10 | parser.link_arguments(
11 | f"data.init_args.{arg}",
12 | f"model.init_args.{arg}",
13 | apply_on="instantiate",
14 | )
15 |
16 | def before_instantiate_classes(self) -> None:
17 | config = self.config[self.subcommand]
18 |
19 | default_root_dir = config.trainer.default_root_dir
20 | logger = config.trainer.logger
21 | if logger and logger is not True:
22 | loggers = logger if isinstance(logger, Iterable) else [logger]
23 | for logger in loggers:
24 | logger.init_args.save_dir = os.path.join(
25 | default_root_dir, self.subcommand
26 | )
27 |
28 |
29 | def lit_cli():
30 | LitCLI(
31 | parser_kwargs={
32 | cmd: {
33 | "default_config_files": ["configs/presets/default.yaml"],
34 | }
35 | for cmd in ["fit", "validate", "test"]
36 | },
37 | save_config_kwargs={"overwrite": True},
38 | )
39 |
40 |
41 | if __name__ == "__main__":
42 | lit_cli()
43 |
--------------------------------------------------------------------------------
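Because of `default_config_files`, every `fit`/`validate`/`test` invocation starts from `configs/presets/default.yaml`, and later `--config` files or command-line flags override it. `--print_config` shows the fully merged configuration without running anything:

```console
./run fit --config configs/mnist.yaml --print_config
```
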
/src/utils/loggers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import lightning as L
4 | from lightning.fabric.utilities.types import _PATH
5 | from lightning.pytorch.callbacks import ModelCheckpoint
6 |
7 |
8 | # TODO:
9 | # https://github.com/Lightning-AI/lightning/issues/14188
10 | # https://github.com/Lightning-AI/lightning/pull/14640
11 | @property
12 | def log_dir(self) -> str:
13 | if self.loggers and self.loggers[0].log_dir is not None:
14 | dirpath = self.loggers[0].log_dir
15 | else:
16 | dirpath = self.default_root_dir
17 |
18 | dirpath = self.strategy.broadcast(dirpath)
19 | return dirpath
20 |
21 |
22 | L.Trainer.log_dir = log_dir
23 |
24 |
25 | def __resolve_ckpt_dir(self, trainer: L.Trainer) -> _PATH:
26 | """Determines model checkpoint save directory at runtime. References attributes from the trainer's logger
27 | to determine where to save checkpoints. The base path for saving weights is set in this priority:
28 | 1. Checkpoint callback's path (if passed in)
29 | 2. The default_root_dir from trainer if trainer has no logger
30 | 3. The log_dir from trainer, if trainer has logger
31 | """
32 | if self.dirpath is not None:
33 | # short circuit if dirpath was passed to ModelCheckpoint
34 | return self.dirpath
35 | if trainer.loggers:
36 | ckpt_path = os.path.join(trainer.log_dir, "checkpoints")
37 | else:
38 | ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
39 | return ckpt_path
40 |
41 |
42 | ModelCheckpoint._ModelCheckpoint__resolve_ckpt_dir = __resolve_ckpt_dir
43 |
--------------------------------------------------------------------------------
/src/utils/sweep_cli.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import json
3 | import math
4 | import os
5 | import shlex
6 | import subprocess
7 | from pathlib import Path
8 | from typing import Any, Literal
9 |
10 | import ray
11 | from jsonargparse import CLI
12 | from ray import train, tune
13 |
14 | ray.init(_temp_dir=str(Path.home() / ".cache" / "ray"))
15 |
16 |
17 | def run_cli(config, debug: bool = True, command: str = "fit", devices: int = 1):
18 | os.chdir(os.environ["TUNE_ORIG_WORKING_DIR"])
19 |
20 | argv = ["./run", command]
21 | ckpt_path = config.pop("ckpt_path", None)
22 | if ckpt_path is not None:
23 | config_path = Path(ckpt_path).parents[1] / "config.yaml"
24 | argv.extend(["--config", str(config_path)])
25 | argv.extend(["--ckpt_path", ckpt_path])
26 | config.pop("config", None)
27 | config.pop("data_config", None)
28 | else:
29 | for cfg in ["config", "data_config"]:
30 | if cfg in config:
31 | argv.extend(["--config", config.pop(cfg)])
32 |
33 | argv.extend(
34 | itertools.chain(
35 | *[
36 | [f"--{k}", v if isinstance(v, str) else json.dumps(v)]
37 | for k, v in config.items()
38 | ]
39 | )
40 | )
41 |
42 | argv.extend(["--trainer.devices", str(devices)])
43 | if debug:
44 | argv.extend(["--config", "configs/presets/tester.yaml"])
45 |
46 | print(shlex.join(argv))
47 | subprocess.check_output(argv)
48 |
49 |
50 | def sweep(
51 | command: Literal["fit", "validate", "test"],
52 | debug: bool = False,
53 | gpus_per_trial: int | float = 1,
54 | *,
55 | ckpt_paths: list[str | None] | None = None,
56 | configs: list[str] | None = None,
57 | data_configs: list[str | None] | None = None,
58 | override_kwargs: dict[str, Any] | None = None,
59 | ):
60 | param_space = {
61 | **({"ckpt_path": tune.grid_search(ckpt_paths)} if ckpt_paths else {}),
62 | **({"config": tune.grid_search(configs)} if configs else {}),
63 | **({"data_config": tune.grid_search(data_configs)} if data_configs else {}),
64 | **(
65 | {
66 | k: tune.grid_search(v) if isinstance(v, list) else tune.grid_search([v])
67 | for k, v in override_kwargs.items()
68 | }
69 | if override_kwargs
70 | else {}
71 | ),
72 | }
73 |
74 | tune_config = tune.TuneConfig()
75 | run_config = train.RunConfig(
76 | log_to_file=True,
77 | storage_path=Path("./results/ray").resolve(),
78 | )
79 | trainable = tune.with_parameters(
80 | run_cli,
81 | debug=debug,
82 | command=command,
83 | devices=math.ceil(gpus_per_trial),
84 | )
85 | tuner = tune.Tuner(
86 | tune.with_resources(trainable, resources={"gpu": gpus_per_trial}),
87 | param_space=param_space,
88 | tune_config=tune_config,
89 | run_config=run_config,
90 | )
91 | tuner.fit()
92 |
93 |
94 | def fit(*args, **kwargs):
95 | sweep("fit", *args, **kwargs)
96 |
97 |
98 | def validate(*args, **kwargs):
99 | sweep("validate", *args, **kwargs)
100 |
101 |
102 | def test(*args, **kwargs):
103 | sweep("test", *args, **kwargs)
104 |
105 |
106 | def sweep_cli():
107 | CLI([fit, validate, test])
108 |
109 |
110 | if __name__ == "__main__":
111 | sweep_cli()
112 |
--------------------------------------------------------------------------------
/src/vendor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tshu-w/lightning-template/810508e9e345dde2f377bd3e8d6471ce585736c0/src/vendor/__init__.py
--------------------------------------------------------------------------------