├── data
│   └── .gitkeep
├── logs
│   └── .gitkeep
├── notebooks
│   └── .gitkeep
├── src
│   ├── __init__.py
│   ├── metrics
│   │   └── __init__.py
│   ├── callbacks
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   └── simplenet.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── cli.py
│   ├── data_modules
│   │   ├── __init__.py
│   │   └── mnist.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── mnist_module.py
│   └── main.py
├── .python-version
├── .project-root
├── .vscode
│   ├── extensions.json
│   ├── settings.json
│   └── launch.json
├── configs
│   ├── data
│   │   └── mnist.yaml
│   ├── model
│   │   └── simplenet.yaml
│   └── default.yaml
├── .env.example
├── LICENSE
├── pyproject.toml
├── scripts
│   └── clear_wandb_cache.py
├── README_PROJECT.md
├── .gitignore
├── README_ZH.md
└── README.md
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 | 
--------------------------------------------------------------------------------
/src/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.project-root:
--------------------------------------------------------------------------------
1 | # This is the indicator file for pyrootutils
2 | 
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from src.models.simplenet import SimpleNet
2 | 
3 | __all__ = [
4 |     "SimpleNet",
5 | ]
6 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from src.utils.cli import CustomLightningCLI
2 | 
3 | __all__ = [
4 |     "CustomLightningCLI",
5 | ]
6 | 
--------------------------------------------------------------------------------
/src/data_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from src.data_modules.mnist import MNISTDataModule
2 | 
3 | __all__ = [
4 |     "MNISTDataModule",
5 | ]
6 | 
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from src.modules.mnist_module import SimpleNetModule
2 | 
3 | __all__ = [
4 |     "SimpleNetModule",
5 | ]
6 | 
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 |   "recommendations":
[ 3 | "ms-python.python", 4 | "ms-toolsai.jupyter", 5 | "charliermarsh.ruff", 6 | "tamasfe.even-better-toml" 7 | ] 8 | } -------------------------------------------------------------------------------- /configs/data/mnist.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_modules.MNISTDataModule 3 | init_args: 4 | data_dir: data 5 | train_val_test_split: 6 | - 55000 7 | - 5000 8 | - 10000 9 | batch_size: 64 10 | num_workers: 2 11 | pin_memory: true 12 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import autorootcwd # noqa: F401 4 | 5 | from src.utils import CustomLightningCLI # noqa: E402 6 | 7 | if __name__ == "__main__": 8 | if os.environ.get("DEBUG", False): 9 | import debugpy 10 | 11 | debugpy.listen(5678) 12 | print("Waiting for debugger attach") 13 | debugpy.wait_for_client() 14 | 15 | cli = CustomLightningCLI() 16 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # This file is loaded by the pyrootutils (python-dotenv) 2 | # Remove .example if you need to use it 3 | 4 | # Limit numpy threads in case they take too many cores 5 | OMP_NUM_THREADS=8 6 | MKL_NUM_THREADS=8 7 | GOTO_NUM_THREADS=8 8 | NUMEXPR_NUM_THREADS=8 9 | OPENBLAS_NUM_THREADS=8 10 | MKL_DOMAIN_NUM_THREADS=8 11 | VECLIB_MAXIMUM_THREADS=8 12 | 13 | # CLI Troubleshooting, https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_faq.html#how-do-i-troubleshoot-a-cli 14 | # JSONARGPARSE_DEBUG=true 15 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "[python]": { 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "explicit", 6 | "source.organizeImports": "explicit" 7 | }, 8 | "editor.defaultFormatter": "charliermarsh.ruff" 9 | }, 10 | "notebook.formatOnSave.enabled": true, 11 | "notebook.codeActionsOnSave": { 12 | "notebook.source.fixAll": "explicit", 13 | "notebook.source.organizeImports": "explicit" 14 | }, 15 | "editor.rulers": [ 16 | 120 17 | ] 18 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Python: Remote Debug", 6 | "type": "debugpy", 7 | "request": "attach", 8 | "listen": { 9 | "host": "0.0.0.0", 10 | "port": 5678 11 | }, 12 | "pathMappings": [ 13 | { 14 | "localRoot": "${workspaceFolder}", 15 | "remoteRoot": "." 
16 | } 17 | ], 18 | "justMyCode": true 19 | } 20 | ] 21 | } -------------------------------------------------------------------------------- /configs/model/simplenet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: modules.SimpleNetModule 3 | # https://pytorch.org/docs/stable/optim.html 4 | optimizer: 5 | class_path: AdamW 6 | init_args: 7 | lr: 0.001 8 | weight_decay: 1e-6 9 | # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate 10 | lr_scheduler: 11 | class_path: ReduceLROnPlateau 12 | init_args: 13 | monitor: val/loss 14 | mode: min 15 | factor: 0.1 16 | patience: 3 17 | verbose: True 18 | threshold: 0.0001 19 | threshold_mode: rel 20 | cooldown: 0 21 | min_lr: 0.00001 22 | eps: 1e-08 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | -------------------------------------------------------------------------------- /src/models/simplenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SimpleNet(nn.Module): 6 | def __init__(self, num_classes=10): 7 | super().__init__() 8 | self.conv = nn.Sequential( 9 | nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0), 10 | nn.BatchNorm2d(6), 11 | nn.ReLU(), 12 | nn.MaxPool2d(kernel_size=2, stride=2), 13 | nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0), 14 | nn.BatchNorm2d(16), 15 | nn.ReLU(), 16 | nn.MaxPool2d(kernel_size=2, stride=2), 17 | ) 18 | self.head = nn.Sequential( 19 | nn.Flatten(), nn.Linear(256, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, num_classes) 20 | ) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.head(x) 25 | return x 26 | 27 | 28 | if __name__ == "__main__": 29 | model = SimpleNet(num_classes=10) 30 | for name, param in model.named_parameters(): 31 | print(name, param.shape) 32 | data = torch.randn(10, 1, 28, 28) 33 | print(model(data).shape) 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pytorch-lightning-template" 3 | version = "0.1.0" 4 | description = "A template for simple deep learning projects using Lightning" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "autorootcwd>=1.0.1", 9 | "debugpy>=1.8.17", 10 | "flash-attn>=2.8.3", 11 | "ipykernel>=7.1.0", 12 | "ipywidgets>=8.1.8", 13 | "jsonargparse[signatures]>=4.43.0", 14 | "matplotlib>=3.10.7", 15 | "opencv-python>=4.12.0.88", 16 | "pandas>=2.3.3", 17 | "pytorch-lightning>=2.5.6", 18 | "pytorch3d>=0.7.8", 19 | "torch>=2.9.1", 20 | "torchmetrics>=1.8.2", 21 | "torchvision>=0.24.1", 22 | "tqdm>=4.67.1", 23 | "transformers>=4.57.2", 24 | "wandb>=0.23.0", 25 | ] 26 | 27 | [tool.uv.sources] 28 | torch = [{ index = "pytorch-cu128" }] 29 | torchvision = [{ index = "pytorch-cu128" }] 30 | flash-attn = [{ index = "torch_packages_builder" }] 31 | pytorch3d = [{ index = "torch_packages_builder" }] 32 | 33 | [[tool.uv.index]] 34 | name = "pytorch-cu128" 35 | url = "https://download.pytorch.org/whl/cu128" 36 | explicit = true 37 | 38 | [[tool.uv.index]] 39 | name = "torch_packages_builder" 40 | url = "https://miropsota.github.io/torch_packages_builder" 41 | explicit = true 42 | 43 | # https://docs.astral.sh/ruff/configuration/ 44 | [tool.ruff] 45 | line-length = 120 46 | select = [ 47 | "I", # isort 48 | "F", # pyflakes 49 | "E", # pycodestyle errors 50 | "W", # pycodestyle warnings 51 | "UP", # pyupgrade 52 | ] 53 | ignore = ["F401"] 54 | 55 | [tool.pyright] 56 | exclude = ["**/__pycache__", "data", "logs"] 57 | typeCheckingMode = "off" 58 | -------------------------------------------------------------------------------- /scripts/clear_wandb_cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to delete unused offline wandb runs and model checkpoints. 
3 | """ 4 | 5 | import os 6 | import shutil 7 | 8 | import autorootcwd # noqa: F401 9 | import wandb 10 | 11 | WANDB_PROJECT = "mnist" 12 | MODEL_CHECKPOINT_PATH = os.path.join("logs", "mnist") 13 | LOG_PATH = os.path.join("logs", "wandb") 14 | WHITE_LIST = [ 15 | "", 16 | ] 17 | 18 | if __name__ == "__main__": 19 | # get run logs 20 | run_id_to_log_path_map = {} 21 | if os.path.exists(LOG_PATH): 22 | for dir in os.listdir(LOG_PATH): 23 | log_path = os.path.join(LOG_PATH, dir) 24 | if os.path.isdir(log_path) and not os.path.islink(log_path): 25 | version = dir.split("-")[-1] 26 | run_id_to_log_path_map[version] = log_path 27 | 28 | # get model checkpoints 29 | run_id_to_model_checkpoint_path_map = {} 30 | if os.path.exists(MODEL_CHECKPOINT_PATH): 31 | for version in os.listdir(MODEL_CHECKPOINT_PATH): 32 | model_checkpoint_path = os.path.join(MODEL_CHECKPOINT_PATH, version) 33 | if os.path.isdir(model_checkpoint_path): 34 | run_id_to_model_checkpoint_path_map[version] = model_checkpoint_path 35 | 36 | # get online runs 37 | api = wandb.Api() 38 | runs = api.runs(WANDB_PROJECT) 39 | online_run_id_to_name_map = {} 40 | for run in runs: 41 | online_run_id_to_name_map[run.id] = run.name 42 | 43 | # delete offline runs 44 | keep_count = 0 45 | delete_count = 0 46 | for run_id, log_path in run_id_to_log_path_map.items(): 47 | if run_id not in online_run_id_to_name_map: 48 | # print(f"Deleting {run_id}") 49 | shutil.rmtree(log_path) 50 | delete_count += 1 51 | else: 52 | # print(f"Keeping {online_run_id_to_name_map[run_id]}({run_id})") 53 | keep_count += 1 54 | print(f"{keep_count} runs kept and {delete_count} runs deleted") 55 | 56 | # delete offline model checkpoints 57 | keep_count = 0 58 | delete_count = 0 59 | for run_id, model_checkpoint_path in run_id_to_model_checkpoint_path_map.items(): 60 | if run_id not in online_run_id_to_name_map: 61 | if run_id in WHITE_LIST: 62 | continue 63 | # print(f"Deleting {run_id}") 64 | shutil.rmtree(model_checkpoint_path) 65 | delete_count += 1 66 | else: 67 | # print(f"Keeping {online_run_id_to_name_map[run_id]}({run_id})") 68 | keep_count += 1 69 | print(f"{keep_count} model checkpoints kept and {delete_count} model checkpoints deleted") 70 | -------------------------------------------------------------------------------- /README_PROJECT.md: -------------------------------------------------------------------------------- 1 |
2 | # Paper Title
3 | 
4 | First Author¹;
5 | Second Author²;
6 | Third Author³;
7 | 
8 | ¹First Affiliation
9 | ²Second Affiliation
10 | ³Third Affiliation
11 | 
12 | 
13 | 
14 | ArXiv
15 | Template
16 | License: MIT
17 | 
18 | 
19 | 
20 | 21 | ![Teaser](https://placehold.co/600x300@2x.png?text=Teaser) 22 | 23 | ## Abstract 24 | 25 | Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur condimentum bibendum nulla in porta. Fusce id eros diam. Aenean ut egestas tortor, at eleifend felis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Phasellus ut ante sit amet lorem commodo fringilla. Nulla rhoncus tincidunt erat vitae pulvinar. Vestibulum dapibus, mauris sed pharetra pellentesque, risus tellus fringilla erat, sit amet egestas urna risus eget nulla. Vestibulum porta, mauris sed commodo dictum, metus ex tempor sapien, eget viverra odio enim quis felis. Nunc aliquam nisi non nisl eleifend rhoncus. Duis sit amet mollis libero, porta hendrerit nunc. Maecenas ultricies sapien ultricies, tristique metus eu, blandit felis. 26 | 27 | ## Method 28 | 29 | ![Method](https://placehold.co/600x300@2x.png?text=Method) 30 | 31 | ## Prerequisites 32 | 33 | ### Installation 34 | 35 | ```bash 36 | # clone project 37 | git clone https://github.com/YOUR_GITHUB_NAME/YOUR_PROJECT_NAME.git 38 | 39 | # install uv, https://docs.astral.sh/uv/getting-started/installation/ 40 | curl -LsSf https://astral.sh/uv/install.sh | sh 41 | 42 | # install dependencies 43 | uv sync 44 | ``` 45 | 46 | ### Data Preparation 47 | 48 | ### Training 49 | 50 | ```bash 51 | python src/main.py fit --config configs/data/mnist.yaml --config configs/model/simplenet.yaml --trainer.logger.init_args.name exp1 52 | ``` 53 | 54 | ### Inference 55 | 56 | ```bash 57 | python src/main.py predict --config configs/data/mnist.yaml --config configs/model/simplenet.yaml --trainer.logger.init_args.name exp1 58 | ``` 59 | 60 | ## Citation 61 | 62 | ``` 63 | @inproceedings{Key, 64 | title={Your Title}, 65 | author={Your team}, 66 | booktitle={Venue}, 67 | year={Year} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | # Global Seed # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility 2 | seed_everything: 3407 3 | # Custom 4 | ignore_warnings: false # Ignore warnings 5 | test_after_fit: false # Apply test after fit finished 6 | git_commit_before_fit: false # Commit before fit 7 | # Trainer Config https://lightning.ai/docs/pytorch/stable/common/trainer.html 8 | trainer: 9 | # Train, Validate, Test and Predict 10 | max_epochs: -1 11 | min_epochs: null 12 | max_steps: -1 13 | min_steps: null 14 | max_time: null 15 | 16 | num_sanity_val_steps: 2 17 | check_val_every_n_epoch: 1 18 | val_check_interval: null 19 | overfit_batches: 0.0 20 | 21 | limit_train_batches: null 22 | limit_val_batches: null 23 | limit_test_batches: null 24 | limit_predict_batches: null 25 | 26 | # Device https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator 27 | accelerator: gpu # "cpu", "gpu", "tpu", "ipu", "auto" 28 | devices: 1 # "2," for device id 2 29 | num_nodes: 1 # https://lightning.ai/docs/pytorch/stable/common/trainer.html#num-nodes 30 | 31 | # Distributed 32 | strategy: auto # https://lightning.ai/docs/pytorch/stable/common/trainer.html#strategy 33 | sync_batchnorm: false # https://lightning.ai/docs/pytorch/stable/common/trainer.html#sync-batchnorm 34 | use_distributed_sampler: true # https://lightning.ai/docs/pytorch/stable/common/trainer.html#lightning.pytorch.trainer.Trainer.params.use_distributed_sampler 35 | 36 | # Logger 
https://lightning.ai/docs/pytorch/latest/visualize/loggers.html 37 | # https://lightning.ai/docs/pytorch/latest/api_references.html#loggers 38 | logger: 39 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.loggers.wandb.html#module-lightning.pytorch.loggers.wandb 40 | class_path: WandbLogger 41 | init_args: 42 | save_dir: logs 43 | project: mnist 44 | log_model: false 45 | log_every_n_steps: 50 46 | 47 | # Callbacks https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html 48 | callbacks: 49 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint 50 | - class_path: ModelCheckpoint 51 | init_args: 52 | filename: epoch={epoch:02d}-val_acc={val/acc:.4f} 53 | monitor: val/acc 54 | verbose: true 55 | save_last: true 56 | save_top_k: 2 57 | mode: max 58 | auto_insert_metric_name: false 59 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping 60 | - class_path: EarlyStopping 61 | init_args: 62 | monitor: val/acc 63 | min_delta: 0.01 64 | patience: 5 65 | verbose: true 66 | mode: max 67 | strict: true 68 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.LearningRateMonitor.html#lightning.pytorch.callbacks.LearningRateMonitor 69 | - class_path: LearningRateMonitor 70 | init_args: 71 | logging_interval: epoch 72 | 73 | # Gradient Clipping https://lightning.ai/docs/pytorch/stable/common/trainer.html#gradient-clip-val 74 | gradient_clip_val: null 75 | gradient_clip_algorithm: null 76 | 77 | # Gradient Accumulation https://lightning.ai/docs/pytorch/stable/common/trainer.html#accumulate-grad-batches 78 | accumulate_grad_batches: 1 79 | 80 | # Precision https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision 81 | precision: 32-true 82 | 83 | # Plugins https://lightning.ai/docs/pytorch/stable/common/trainer.html#plugins 84 | plugins: null 85 | 86 | # Debug 87 | fast_dev_run: false # https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run 88 | profiler: null # https://lightning.ai/docs/pytorch/stable/api_references.html#profiler 89 | barebones: false 90 | detect_anomaly: false 91 | reload_dataloaders_every_n_epochs: 0 92 | 93 | # Misc 94 | inference_mode: true 95 | default_root_dir: null 96 | benchmark: null 97 | deterministic: false 98 | enable_progress_bar: true 99 | enable_checkpointing: true 100 | enable_model_summary: true 101 | -------------------------------------------------------------------------------- /src/data_modules/mnist.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/ashleve/lightning-hydra-template/blob/main/src/data/mnist_datamodule.py 2 | from typing import Any 3 | 4 | import pytorch_lightning as pl 5 | from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split 6 | from torchvision.datasets import MNIST 7 | from torchvision.transforms import transforms as T 8 | 9 | 10 | class MNISTDataModule(pl.LightningDataModule): 11 | """Example of LightningDataModule for MNIST dataset. 12 | A DataModule implements 5 key methods: 13 | def prepare_data(self): 14 | # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) 15 | # download data, pre-process, split, save to disk, etc... 16 | def setup(self, stage): 17 | # things to do on every process in DDP 18 | # load data, set variables, etc... 
19 | def train_dataloader(self): 20 | # return train dataloader 21 | def val_dataloader(self): 22 | # return validation dataloader 23 | def test_dataloader(self): 24 | # return test dataloader 25 | def teardown(self): 26 | # called on every process in DDP 27 | # clean up after fit or test 28 | This allows you to share a full dataset without explaining how to download, 29 | split, transform and process the data. 30 | Read the docs: 31 | https://lightning.ai/docs/pytorch/latest/data/datamodule.html 32 | """ 33 | 34 | def __init__( 35 | self, 36 | data_dir: str = "data", 37 | train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000), 38 | batch_size: int = 64, 39 | num_workers: int = 0, 40 | pin_memory: bool = False, 41 | ): 42 | super().__init__() 43 | 44 | # this line allows to access init params with 'self.hparams' attribute 45 | # also ensures init params will be stored in ckpt 46 | self.save_hyperparameters() 47 | 48 | # data transformations 49 | self.transforms = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) 50 | 51 | self.data_train: Dataset = None 52 | self.data_val: Dataset = None 53 | self.data_test: Dataset = None 54 | 55 | @property 56 | def num_classes(self): 57 | return 10 58 | 59 | def prepare_data(self): 60 | """Download data if needed. 61 | Do not use it to assign state (self.x = y). 62 | """ 63 | MNIST(self.hparams.data_dir, train=True, download=True) 64 | MNIST(self.hparams.data_dir, train=False, download=True) 65 | 66 | def setup(self, stage: str = None): 67 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 68 | This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be 69 | careful not to execute things like random split twice! 70 | """ 71 | # load and split datasets only if not loaded already 72 | if not self.data_train and not self.data_val and not self.data_test: 73 | trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms) 74 | testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms) 75 | dataset = ConcatDataset(datasets=[trainset, testset]) 76 | self.data_train, self.data_val, self.data_test = random_split( 77 | dataset=dataset, 78 | lengths=self.hparams.train_val_test_split, 79 | ) 80 | 81 | def train_dataloader(self): 82 | return DataLoader( 83 | dataset=self.data_train, 84 | batch_size=self.hparams.batch_size, 85 | num_workers=self.hparams.num_workers, 86 | pin_memory=self.hparams.pin_memory, 87 | shuffle=True, 88 | ) 89 | 90 | def val_dataloader(self): 91 | return DataLoader( 92 | dataset=self.data_val, 93 | batch_size=self.hparams.batch_size, 94 | num_workers=self.hparams.num_workers, 95 | pin_memory=self.hparams.pin_memory, 96 | shuffle=False, 97 | ) 98 | 99 | def test_dataloader(self): 100 | return DataLoader( 101 | dataset=self.data_test, 102 | batch_size=self.hparams.batch_size, 103 | num_workers=self.hparams.num_workers, 104 | pin_memory=self.hparams.pin_memory, 105 | shuffle=False, 106 | ) 107 | 108 | def teardown(self, stage: str = None): 109 | """Clean up after fit or test.""" 110 | pass 111 | 112 | def state_dict(self): 113 | """Extra things to save to checkpoint.""" 114 | return {} 115 | 116 | def load_state_dict(self, state_dict: dict[str, Any]): 117 | """Things to do when loading checkpoint.""" 118 | pass 119 | 120 | 121 | if __name__ == "__main__": 122 | dm = MNISTDataModule() 123 | dm.prepare_data() 124 | dm.setup() 125 | for batch in dm.train_dataloader(): 126 | print(batch[0].shape) 127 | print(batch[1].shape) 128 
| break 129 | -------------------------------------------------------------------------------- /src/modules/mnist_module.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py 2 | from typing import Any 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from torchmetrics import MaxMetric, MeanMetric 7 | from torchmetrics.classification.accuracy import Accuracy 8 | 9 | from src.models import SimpleNet 10 | 11 | 12 | class SimpleNetModule(pl.LightningModule): 13 | """Example of LightningModule for MNIST classification. 14 | A LightningModule organizes your PyTorch code into 6 sections: 15 | - Computations (init) 16 | - Train loop (training_step) 17 | - Validation loop (validation_step) 18 | - Test loop (test_step) 19 | - Prediction Loop (predict_step) 20 | - Optimizers and LR Schedulers (configure_optimizers) 21 | Docs: 22 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html 23 | """ 24 | 25 | def __init__( 26 | self, 27 | ): 28 | super().__init__() 29 | 30 | # this line allows to access init params with 'self.hparams' attribute 31 | # also ensures init params will be stored in ckpt 32 | self.save_hyperparameters() 33 | 34 | self.net = SimpleNet() 35 | 36 | # loss function 37 | self.criterion = torch.nn.CrossEntropyLoss() 38 | 39 | # metric objects for calculating and averaging accuracy across batches 40 | metric = Accuracy(task="multiclass", num_classes=10) 41 | self.train_acc = metric.clone() 42 | self.val_acc = metric.clone() 43 | self.test_acc = metric.clone() 44 | 45 | # for averaging loss across batches 46 | loss_metric = MeanMetric() 47 | self.train_loss = loss_metric.clone() 48 | self.val_loss = loss_metric.clone() 49 | self.test_loss = loss_metric.clone() 50 | 51 | # for tracking best so far validation accuracy 52 | self.val_acc_best = MaxMetric() 53 | 54 | def forward(self, x: torch.Tensor): 55 | return self.net(x) 56 | 57 | def on_train_start(self): 58 | # by default lightning executes validation step sanity checks before training starts, 59 | # so we need to make sure val_acc_best doesn't store accuracy from these checks 60 | self.val_acc_best.reset() 61 | 62 | def model_step(self, batch: Any): 63 | x, y = batch 64 | logits = self.forward(x) 65 | loss = self.criterion(logits, y) 66 | preds = torch.argmax(logits, dim=1) 67 | return loss, preds, y 68 | 69 | def training_step(self, batch: Any, batch_idx: int): 70 | loss, preds, targets = self.model_step(batch) 71 | 72 | # update and log metrics 73 | self.train_loss(loss) 74 | self.train_acc(preds, targets) 75 | self.log("train/loss", self.train_loss, on_step=False, on_epoch=True) 76 | self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) 77 | 78 | # we can return here dict with any tensors 79 | # and then read it in some callback or in `training_epoch_end()` below 80 | # remember to always return loss from `training_step()` or backpropagation will fail! 
81 | return {"loss": loss, "preds": preds, "targets": targets} 82 | 83 | def on_train_epoch_end(self): 84 | # `outputs` is a list of dicts returned from `training_step()` 85 | 86 | # Warning: when overriding `training_epoch_end()`, lightning accumulates outputs from all batches of the epoch 87 | # this may not be an issue when training on mnist 88 | # but on larger datasets/models it's easy to run into out-of-memory errors 89 | 90 | # consider detaching tensors before returning them from `training_step()` 91 | # or using `on_train_epoch_end()` instead which doesn't accumulate outputs 92 | 93 | pass 94 | 95 | def validation_step(self, batch: Any, batch_idx: int): 96 | loss, preds, targets = self.model_step(batch) 97 | 98 | # update and log metrics 99 | self.val_loss(loss) 100 | self.val_acc(preds, targets) 101 | self.log("val/loss", self.val_loss, on_step=False, on_epoch=True) 102 | self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) 103 | 104 | return {"loss": loss, "preds": preds, "targets": targets} 105 | 106 | def on_validation_epoch_end(self): 107 | acc = self.val_acc.compute() # get current val acc 108 | self.val_acc_best(acc) # update best so far val acc 109 | # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object 110 | # otherwise metric would be reset by lightning after each epoch 111 | self.log("val/acc_best", self.val_acc_best.compute()) 112 | 113 | def test_step(self, batch: Any, batch_idx: int): 114 | loss, preds, targets = self.model_step(batch) 115 | 116 | # update and log metrics 117 | self.test_loss(loss) 118 | self.test_acc(preds, targets) 119 | self.log("test/loss", self.test_loss, on_step=False, on_epoch=True) 120 | self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) 121 | 122 | return {"loss": loss, "preds": preds, "targets": targets} 123 | 124 | def on_test_epoch_end(self): 125 | pass 126 | 127 | 128 | if __name__ == "__main__": 129 | m = SimpleNetModule() 130 | print(m) 131 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Custom 2 | 3 | demo.py 4 | .env 5 | logs/* 6 | !logs/.gitkeep 7 | data/* 8 | !data/.gitkeep 9 | 10 | ### VisualStudioCode template 11 | .vscode/* 12 | !.vscode/settings.json 13 | !.vscode/tasks.json 14 | !.vscode/launch.json 15 | !.vscode/extensions.json 16 | !.vscode/*.code-snippets 17 | 18 | # Local History for Visual Studio Code 19 | .history/ 20 | 21 | # Built Visual Studio Code Extensions 22 | *.vsix 23 | 24 | ### Windows template 25 | # Windows thumbnail cache files 26 | Thumbs.db 27 | Thumbs.db:encryptable 28 | ehthumbs.db 29 | ehthumbs_vista.db 30 | 31 | # Dump file 32 | *.stackdump 33 | 34 | # Folder config file 35 | [Dd]esktop.ini 36 | 37 | # Recycle Bin used on file shares 38 | $RECYCLE.BIN/ 39 | 40 | # Windows Installer files 41 | *.cab 42 | *.msi 43 | *.msix 44 | *.msm 45 | *.msp 46 | 47 | # Windows shortcuts 48 | *.lnk 49 | 50 | ### JupyterNotebooks template 51 | # gitignore template for Jupyter Notebooks 52 | # website: http://jupyter.org/ 53 | 54 | .ipynb_checkpoints 55 | */.ipynb_checkpoints/* 56 | 57 | # IPython 58 | profile_default/ 59 | ipython_config.py 60 | 61 | # Remove previous ipynb_checkpoints 62 | # git rm -r .ipynb_checkpoints/ 63 | 64 | ### macOS template 65 | # General 66 | .DS_Store 67 | .AppleDouble 68 | .LSOverride 69 | 70 | # Icon must end with two \r 71 | Icon 72 | 73 | # Thumbnails 74 | ._* 75 | 76 | # Files that might 
appear in the root of a volume 77 | .DocumentRevisions-V100 78 | .fseventsd 79 | .Spotlight-V100 80 | .TemporaryItems 81 | .Trashes 82 | .VolumeIcon.icns 83 | .com.apple.timemachine.donotpresent 84 | 85 | # Directories potentially created on remote AFP share 86 | .AppleDB 87 | .AppleDesktop 88 | Network Trash Folder 89 | Temporary Items 90 | .apdisk 91 | 92 | ### Python template 93 | # Byte-compiled / optimized / DLL files 94 | __pycache__/ 95 | *.py[cod] 96 | *$py.class 97 | 98 | # C extensions 99 | *.so 100 | 101 | # Distribution / packaging 102 | .Python 103 | build/ 104 | develop-eggs/ 105 | dist/ 106 | downloads/ 107 | eggs/ 108 | .eggs/ 109 | lib/ 110 | lib64/ 111 | parts/ 112 | sdist/ 113 | var/ 114 | wheels/ 115 | share/python-wheels/ 116 | *.egg-info/ 117 | .installed.cfg 118 | *.egg 119 | MANIFEST 120 | 121 | # PyInstaller 122 | # Usually these files are written by a python script from a template 123 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 124 | *.manifest 125 | *.spec 126 | 127 | # Installer logs 128 | pip-log.txt 129 | pip-delete-this-directory.txt 130 | 131 | # Unit test / coverage reports 132 | htmlcov/ 133 | .tox/ 134 | .nox/ 135 | .coverage 136 | .coverage.* 137 | .cache 138 | nosetests.xml 139 | coverage.xml 140 | *.cover 141 | *.py,cover 142 | .hypothesis/ 143 | .pytest_cache/ 144 | cover/ 145 | 146 | # Translations 147 | *.mo 148 | *.pot 149 | 150 | # Django stuff: 151 | *.log 152 | local_settings.py 153 | db.sqlite3 154 | db.sqlite3-journal 155 | 156 | # Flask stuff: 157 | instance/ 158 | .webassets-cache 159 | 160 | # Scrapy stuff: 161 | .scrapy 162 | 163 | # Sphinx documentation 164 | docs/_build/ 165 | 166 | # PyBuilder 167 | .pybuilder/ 168 | target/ 169 | 170 | # Jupyter Notebook 171 | .ipynb_checkpoints 172 | 173 | # IPython 174 | profile_default/ 175 | ipython_config.py 176 | 177 | # pyenv 178 | # For a library or package, you might want to ignore these files since the code is 179 | # intended to run in multiple environments; otherwise, check them in: 180 | # .python-version 181 | 182 | # pipenv 183 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 184 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 185 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 186 | # install all needed dependencies. 187 | #Pipfile.lock 188 | 189 | # poetry 190 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 191 | # This is especially recommended for binary packages to ensure reproducibility, and is more 192 | # commonly ignored for libraries. 193 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 194 | #poetry.lock 195 | 196 | # pdm 197 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 198 | #pdm.lock 199 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 200 | # in version control. 201 | # https://pdm.fming.dev/#use-with-ide 202 | .pdm.toml 203 | 204 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 205 | __pypackages__/ 206 | 207 | # Celery stuff 208 | celerybeat-schedule 209 | celerybeat.pid 210 | 211 | # SageMath parsed files 212 | *.sage.py 213 | 214 | # Environments 215 | .env 216 | .venv 217 | env/ 218 | venv/ 219 | ENV/ 220 | env.bak/ 221 | venv.bak/ 222 | 223 | # Spyder project settings 224 | .spyderproject 225 | .spyproject 226 | 227 | # Rope project settings 228 | .ropeproject 229 | 230 | # mkdocs documentation 231 | /site 232 | 233 | # mypy 234 | .mypy_cache/ 235 | .dmypy.json 236 | dmypy.json 237 | 238 | # Pyre type checker 239 | .pyre/ 240 | 241 | # pytype static type analyzer 242 | .pytype/ 243 | 244 | # Cython debug symbols 245 | cython_debug/ 246 | 247 | # PyCharm 248 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 249 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 250 | # and can be added to the global gitignore or merged into this file. For a more nuclear 251 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 252 | .idea/ 253 | 254 | -------------------------------------------------------------------------------- /src/utils/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Any, Dict, Optional, Type 4 | 5 | from lightning_fabric.utilities.cloud_io import get_filesystem 6 | from pytorch_lightning import LightningModule, Trainer 7 | from pytorch_lightning.cli import ( 8 | LightningArgumentParser, 9 | LightningCLI, 10 | LRSchedulerTypeUnion, 11 | ReduceLROnPlateau, 12 | SaveConfigCallback, 13 | ) 14 | from pytorch_lightning.loggers import Logger, WandbLogger 15 | from torch.optim import Optimizer 16 | from torch.optim.lr_scheduler import CyclicLR, OneCycleLR 17 | 18 | 19 | class WandbSaveConfigCallback(SaveConfigCallback): 20 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: 21 | if self.already_saved: 22 | return 23 | 24 | log_dir = trainer.log_dir # this broadcasts the directory 25 | if trainer.logger is not None and trainer.logger.name is not None and trainer.logger.version is not None: 26 | log_dir = os.path.join(log_dir, trainer.logger.name, str(trainer.logger.version)) 27 | config_path = os.path.join(log_dir, self.config_filename) 28 | fs = get_filesystem(log_dir) 29 | 30 | if not self.overwrite: 31 | # check if the file exists on rank 0 32 | file_exists = fs.isfile(config_path) if trainer.is_global_zero else False 33 | # broadcast whether to fail to all ranks 34 | file_exists = trainer.strategy.broadcast(file_exists) 35 | if file_exists: 36 | raise RuntimeError( 37 | f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" 38 | " results of a previous run. You can delete the previous config file," 39 | " set `LightningCLI(save_config_callback=None)` to disable config saving," 40 | ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.' 41 | ) 42 | 43 | # save the file on rank 0 44 | if trainer.is_global_zero: 45 | # save only on rank zero to avoid race conditions. 
46 | # the `log_dir` needs to be created as we rely on the logger to do it usually 47 | # but it hasn't logged anything at this point 48 | fs.makedirs(log_dir, exist_ok=True) 49 | self.parser.save( 50 | self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile 51 | ) 52 | self.already_saved = True 53 | # save optimizer and lr scheduler config 54 | for _logger in trainer.loggers: 55 | if isinstance(_logger, Logger): 56 | config = {} 57 | if "optimizer" in self.config: 58 | config["optimizer"] = { 59 | k.replace("init_args.", ""): v for k, v in dict(self.config["optimizer"]).items() 60 | } 61 | if "lr_scheduler" in self.config: 62 | config["lr_scheduler"] = { 63 | k.replace("init_args.", ""): v for k, v in dict(self.config["lr_scheduler"]).items() 64 | } 65 | _logger.log_hyperparams(config) 66 | 67 | # broadcast so that all ranks are in sync on future calls to .setup() 68 | self.already_saved = trainer.strategy.broadcast(self.already_saved) 69 | 70 | 71 | class CustomLightningCLI(LightningCLI): 72 | def __init__( 73 | self, 74 | save_config_callback: Optional[Type[SaveConfigCallback]] = WandbSaveConfigCallback, 75 | parser_kwargs: Optional[Dict[str, Any]] = None, 76 | **kwargs: Any, 77 | ) -> None: 78 | new_parser_kwargs = { 79 | sub_command: dict(default_config_files=[os.path.join("configs", "default.yaml")]) 80 | for sub_command in ["fit", "validate", "test", "predict"] 81 | } 82 | new_parser_kwargs.update(parser_kwargs or {}) 83 | super().__init__(save_config_callback=save_config_callback, parser_kwargs=new_parser_kwargs, **kwargs) 84 | 85 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 86 | parser.add_argument("--ignore_warnings", default=False, type=bool, help="Ignore warnings") 87 | parser.add_argument("--git_commit_before_fit", default=False, type=bool, help="Git commit before training") 88 | parser.add_argument( 89 | "--test_after_fit", default=False, type=bool, help="Run test on the best checkpoint after training" 90 | ) 91 | 92 | def before_instantiate_classes(self) -> None: 93 | if self.config[self.subcommand].get("ignore_warnings"): 94 | warnings.filterwarnings("ignore") 95 | 96 | def before_fit(self) -> None: 97 | if self.config.fit.get("git_commit_before_fit") and not os.environ.get("DEBUG", False): 98 | logger = self.trainer.logger 99 | if isinstance(logger, WandbLogger): 100 | version = getattr(logger, "version") 101 | name = getattr(logger, "_name") 102 | message = "Commit Message" 103 | if name and version: 104 | message = f"{name}_{version}" 105 | elif name: 106 | message = name 107 | elif version: 108 | message = version 109 | os.system(f'git commit -am "{message}"') 110 | 111 | def after_fit(self) -> None: 112 | if self.config.fit.get("test_after_fit") and not os.environ.get("DEBUG", False): 113 | self._run_subcommand("test") 114 | 115 | def before_test(self) -> None: 116 | if self.trainer.checkpoint_callback and self.trainer.checkpoint_callback.best_model_path: 117 | tested_ckpt_path = self.trainer.checkpoint_callback.best_model_path 118 | elif self.config_init[self.config_init["subcommand"]]["ckpt_path"]: 119 | return 120 | else: 121 | tested_ckpt_path = None 122 | self.config_init[self.config_init["subcommand"]]["ckpt_path"] = tested_ckpt_path 123 | 124 | def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: 125 | """Prepares the keyword arguments to pass to the subcommand to run.""" 126 | fn_kwargs = { 127 | k: v 128 | for k, v in 
self.config_init[self.config_init["subcommand"]].items() 129 | if k in self._subcommand_method_arguments[subcommand] 130 | } 131 | fn_kwargs["model"] = self.model 132 | if self.datamodule is not None: 133 | fn_kwargs["datamodule"] = self.datamodule 134 | return fn_kwargs 135 | 136 | @staticmethod 137 | def configure_optimizers( 138 | lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None 139 | ) -> Any: 140 | """Override to customize the :meth:`~pytorch_lightning.core.LightningModule.configure_optimizers` method. 141 | 142 | Args: 143 | lightning_module: A reference to the model. 144 | optimizer: The optimizer. 145 | lr_scheduler: The learning rate scheduler (if used). 146 | 147 | """ 148 | if lr_scheduler is None: 149 | return optimizer 150 | if isinstance(lr_scheduler, ReduceLROnPlateau): 151 | return { 152 | "optimizer": optimizer, 153 | "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor}, 154 | } 155 | if isinstance(lr_scheduler, (OneCycleLR, CyclicLR)): 156 | # CyclicLR and OneCycleLR are step-based schedulers, where the default interval is "epoch". 157 | return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"}} 158 | return [optimizer], [lr_scheduler] 159 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | # Pytorch Lightning Template
4 | 
5 | [![Pytorch](https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
6 | [![Pytorch Lightning](https://img.shields.io/badge/-Lightning-ffffff?logo=data:image/svg%2bxml;base64,PHN2ZyB3aWR0aD0iNDEiIGhlaWdodD0iNDgiIHZpZXdCb3g9IjAgMCA0MSA0OCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTIwLjQ5OTIgMEwwIDEyVjM2TDIwLjUgNDhMNDEgMzZWMTJMMjAuNDk5MiAwWk0xNi45NTAxIDM2LjAwMTZMMTkuMTA4OSAyNi42ODU2TDE0LjI1NDggMjEuODkyTDI0LjA3OTEgMTEuOTk5MkwyMS45MTYzIDIxLjMyOTZMMjYuNzQ0NCAyNi4wOTc2TDE2Ljk1MDEgMzYuMDAxNloiIGZpbGw9IiM3OTJFRTUiLz4KPC9zdmc+Cg==)](https://lightning.ai/docs/pytorch/stable/)
7 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
8 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)
9 | 
10 | A template for simple deep learning projects using the full Lightning suite
11 | 
12 | [English](./README.md) | Chinese
13 | 
14 | 
15 | 
16 | ## Introduction
17 | 
18 | [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) is to deep learning project development what [MVC](https://en.wikipedia.org/wiki/Model%E2%80%93view%E2%80%93controller) frameworks (such as [Spring](https://spring.io/) and [Django](https://www.djangoproject.com/)) are to website development. You can of course implement everything by hand and get maximum flexibility (especially since [PyTorch](https://pytorch.org/) and its ecosystem are already simple enough), but using a framework helps you build prototypes quickly under the guidance of existing ["best practices"](#best-practice) (personal opinion only), reuse code to save a large amount of boilerplate, and focus on research innovation rather than engineering problems. This template is built on the full `lightning` suite, tries to follow [Occam's razor](https://en.wikipedia.org/wiki/Occam%27s_razor) while staying friendly to researchers, and implements a simple handwritten digit recognition task on [MNIST](https://en.wikipedia.org/wiki/MNIST_database). The repository also records some [Tips](#tips) gathered during use, for reference.
19 | 
20 | ## "Best Practice"
21 | 
22 | ### Use [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) as the deep learning framework
23 | 
24 | Most deep learning code can be divided into the following three parts ([reference, in Chinese](https://zhuanlan.zhihu.com/p/120331610)):
25 | 
26 | 1. Research code
27 |    This is the model part, generally handling customizations such as the model structure and training. In `Lightning`, this code is abstracted as the `pl.LightningModule` class. The dataset definition could also live here, but that is not recommended: it is unrelated to the experiment itself and belongs in `pl.LightningDataModule`.
28 | 2. Engineering code
29 |    The key characteristic of this code is that it is highly repetitive, e.g. setting up early stopping, 16-bit precision, or distributed GPU training. In `Lightning`, it is abstracted as the `pl.Trainer` class.
30 | 3. Non-essential code
31 |    This code helps the experiment along but is not directly related to it and can even be omitted, e.g. checking gradients or writing logs to `TensorBoard`. In `Lightning`, it is abstracted as `Callbacks` registered with `pl.Trainer`.
32 | 
33 | The advantages of `Lightning`:
34 | 
35 | 1. The various [hooks](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks) in `pl.LightningModule` make it possible to customize the training procedure, the learning-rate schedule, and more.
36 | 2. Models and data no longer need to be assigned to devices explicitly (`tensor.to`, `tensor.cuda`, etc.); `pl.Trainer` handles this automatically, supporting a variety of [accelerators (CPU, GPU, TPU, etc.)](https://lightning.ai/docs/pytorch/latest/extensions/accelerator.html).
37 | 3. `pl.Trainer` implements several [training strategies](https://lightning.ai/docs/pytorch/latest/extensions/strategy.html), such as automatic mixed-precision training, multi-GPU training, and distributed training.
38 | 4. `pl.Trainer` supports a variety of [callbacks](https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html), such as automatically saving checkpoints, logs, and visualization results.
39 | 
40 | ### Use [Pytorch Lightning CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli) as the command-line tool
41 | 
42 | 1. With `lightning_cli` as the program entry point, model, data, and training parameters can be set through config files or command-line arguments, making it quick to switch between experiments.
43 | 2. `pl.LightningModule.save_hyperparameters()` saves the model's hyperparameters and generates the command-line argument table automatically, with no need to hand-roll it using tools such as [`argparse`](https://docs.python.org/3/library/argparse.html) or [`hydra`](https://hydra.cc/) (a minimal sketch follows below).
44 | 
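To make point 2 concrete, here is a minimal, hypothetical sketch (`LitClassifier` and its arguments are invented for illustration and are not part of this template): every typed `__init__` argument automatically becomes a CLI flag and a config-file entry.

```python
import pytorch_lightning as pl
import torch
from torch import nn


class LitClassifier(pl.LightningModule):
    # jsonargparse reads this typed signature, so the CLI automatically
    # accepts e.g. `fit --model.hidden_dim 256 --model.lr 3e-4`.
    def __init__(self, hidden_dim: int = 128, lr: float = 1e-3):
        super().__init__()
        # Stores hidden_dim/lr in self.hparams and in every checkpoint.
        self.save_hyperparameters()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 10))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.net(x), y)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```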
45 | ### Use [Torchmetrics](https://torchmetrics.readthedocs.io/en/stable/) for metric computation
46 | 
47 | 1. It ships with many metric implementations, such as `Accuracy`, `Precision`, and `Recall`.
48 | 2. It integrates with `Lightning` and is compatible with parallel training strategies: data is automatically gathered to the main process for metric computation.
49 | 
50 | ### [Optional] Use [WanDB](https://wandb.ai/) to record experiment logs
51 | 
52 | ### Project Architecture
53 | 
54 | ```mermaid
55 | graph TD;
56 |     A[LightningCLI]---B[LightningModule]
57 |     A---C[LightningDataModule]
58 |     B---D[models]
59 |     B---E[metrics]
60 |     B---F[...]
61 |     C---G[dataloaders]
62 |     G---H[datasets]
63 | ```
64 | 
65 | ### File Structure
66 | 
67 | ```text
68 | .
69 | ├── configs                    # Configuration files
70 | │   ├── data                   # Dataset configs
71 | │   │   └── mnist.yaml         # Example config for the MNIST dataset
72 | │   ├── model                  # Model configs
73 | │   │   └── simplenet.yaml     # Example config for the SimpleNet model
74 | │   └── default.yaml           # Default config
75 | ├── data                       # Dataset directory
76 | ├── logs                       # Log directory
77 | ├── notebooks                  # Jupyter Notebook directory
78 | ├── scripts                    # Script directory
79 | │   └── clear_wandb_cache.py   # Example script for clearing the wandb cache
80 | ├── src                        # Source code directory
81 | │   ├── callbacks              # Callbacks directory
82 | │   │   └── __init__.py
83 | │   ├── data_modules           # Data module directory
84 | │   │   ├── __init__.py
85 | │   │   └── mnist.py           # Example MNIST data module
86 | │   ├── metrics                # Metrics directory
87 | │   │   └── __init__.py
88 | │   ├── models                 # Model directory
89 | │   │   ├── __init__.py
90 | │   │   └── simplenet.py       # Example SimpleNet model
91 | │   ├── modules                # Module directory
92 | │   │   ├── __init__.py
93 | │   │   └── mnist_module.py    # Example MNIST module
94 | │   ├── utils                  # Utility directory
95 | │   │   ├── __init__.py
96 | │   │   └── cli.py             # CLI tool
97 | │   ├── __init__.py
98 | │   └── main.py                # Main program entry point
99 | ├── .env.example               # Example environment variables
100 | ├── .gitignore                # Git ignore list
101 | ├── .project-root             # Project-root indicator file for pyrootutils
102 | ├── LICENSE                   # Open-source license
103 | ├── pyproject.toml            # Project metadata, dependencies, and Ruff configuration
104 | ├── README_PROJECT.md         # Project documentation template
105 | ├── README.md                 # This documentation
106 | └── README_ZH.md              # This documentation (Chinese)
107 | ```
108 | 
109 | ## Usage
110 | 
111 | ### Installation
112 | 
113 | ```bash
114 | # Clone the project
115 | git clone https://github.com/DavidZhang73/pytorch-lightning-template
116 | cd pytorch-lightning-template
117 | 
118 | # Install uv, https://docs.astral.sh/uv/getting-started/installation/
119 | curl -LsSf https://astral.sh/uv/install.sh | sh
120 | 
121 | # Install dependencies
122 | uv sync
123 | ```
124 | 
125 | ### Configuration
126 | 
127 | 1. Define your dataset in `src/data_modules` by subclassing [`pl.LightningDataModule`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html)
128 | 2. Define the dataset's config file in `configs/data` (the arguments of your custom `pl.LightningDataModule`; see the sketch after this list)
129 | 3. Define your model in `src/models` by subclassing `nn.Module`
130 | 4. Define metrics in `src/metrics` by subclassing [`torchmetrics.Metric`](https://torchmetrics.readthedocs.io/en/stable/pages/implement.html)
131 | 5. Define the training module in `src/modules` by subclassing [`pl.LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html)
132 | 6. Define the training module's config file in `configs/model` (the arguments of your custom `pl.LightningModule`)
133 | 7. Configure the [`pl.Trainer`](https://lightning.ai/docs/pytorch/stable/common/trainer.html), logger, and other parameters in `configs/default.yaml`
134 | 
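A hypothetical sketch of steps 1-2 above (the class `RandomDataModule` and the file `configs/data/random.yaml` are made-up names, not files shipped with the template): the data module's typed `__init__` arguments map one-to-one onto the `init_args` keys of its config file.

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class RandomDataModule(pl.LightningDataModule):
    # A matching configs/data/random.yaml would read:
    #   data:
    #     class_path: data_modules.RandomDataModule
    #     init_args:
    #       num_samples: 1000
    #       batch_size: 32
    def __init__(self, num_samples: int = 1000, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()  # exposes num_samples/batch_size to the CLI

    def setup(self, stage: str | None = None):
        # Toy in-memory dataset standing in for real data loading.
        x = torch.randn(self.hparams.num_samples, 1, 28, 28)
        y = torch.randint(0, 10, (self.hparams.num_samples,))
        self.train_set = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True)
```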
135 | ### Run
136 | 
137 | **Training**
138 | 
139 | ```bash
140 | python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
141 | ```
142 | 
143 | **Validation**
144 | 
145 | ```bash
146 | python src/main.py validate -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
147 | ```
148 | 
149 | **Testing**
150 | 
151 | ```bash
152 | python src/main.py test -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
153 | ```
154 | 
155 | **Inference**
156 | 
157 | ```bash
158 | python src/main.py predict -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
159 | ```
160 | 
161 | **Debugging**
162 | 
163 | ```bash
164 | python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.fast_dev_run true
165 | ```
166 | 
167 | **Resume training**
168 | 
169 | ```bash
170 | python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --ckpt_path <ckpt_path> --trainer.logger.id exp1_id
171 | ```
172 | 
173 | ## Tips
174 | 
175 | ### Get the parsed arguments and generate a default `yaml` file
176 | 
177 | Using the `print_config` feature of `jsonargparse`, you can obtain the parsed arguments, generate a default `yaml` file, and so on — the `data` and `model` `yaml` files need to be set up first:
178 | 
179 | ```bash
180 | python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --print_config
181 | ```
182 | 
183 | [Prepare a config file for the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html#prepare-a-config-file-for-the-cli)
184 | 
185 | ### Customizing `LightningCLI`
186 | 
187 | This template implements a custom CLI (`CustomLightningCLI`) that adds the following features:
188 | 
189 | - On every launch, automatically save the config file to the corresponding log directory (currently implemented for `WandbLogger` only)
190 | - On every launch, save the `optimizer` and `lr_scheduler` parameters to the `Loggers`
191 | - On every launch, automatically load the default config file
192 | - After testing finishes, print the `checkpoint_path` used for the test
193 | - Add a few command-line arguments:
194 |   - `--ignore_warnings` (default: `False`): ignore all warnings
195 |   - `--test_after_fit` (default: `False`): automatically run a test after each training run finishes
196 |   - `--git_commit_before_fit` (default: `False`): run `git commit` before each training run, with the commit message `{logger.name}_{logger.version}` (currently implemented for `WandbLogger` only)
197 | 
198 | [CONFIGURE HYPERPARAMETERS FROM THE CLI (EXPERT)](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html)
199 | 
200 | ### Limiting the number of `numpy` threads
201 | 
202 | When running on a server, especially one whose CPU has many cores (>=24), `numpy` may spawn too many threads, which can make experiments hang for no apparent reason. You can limit the number of `numpy` threads via environment variables (in the `.env` file).
203 | 
204 | ```text
205 | OMP_NUM_THREADS=8
206 | MKL_NUM_THREADS=8
207 | GOTO_NUM_THREADS=8
208 | NUMEXPR_NUM_THREADS=8
209 | OPENBLAS_NUM_THREADS=8
210 | MKL_DOMAIN_NUM_THREADS=8
211 | VECLIB_MAXIMUM_THREADS=8
212 | ```
213 | 
214 | > The `.env` file is automatically loaded into the environment by the [`pyrootutils`](https://github.com/ashleve/pyrootutils) library via [`python-dotenv`](https://github.com/theskumar/python-dotenv)
215 | 
216 | [Stack Overflow: Limit number of threads in numpy](https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy)
217 | 
218 | ### Clearing the `wandb` cache
219 | 
220 | When you have deleted a run on the `wandb` web page but its cache still exists in the local `wandb` directory, the script `scripts/clear_wandb_cache.py` can be used to clear the cache.
221 | 
222 | [Wandb Python Documentation](https://docs.wandb.ai/ref/python/)
223 | 
224 | ## References
225 | 
226 | Inspired by
227 | 
228 | - [deep-learning-project-template](https://github.com/Lightning-AI/deep-learning-project-template)
229 | - [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)
230 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 3 | # Pytorch Lightning Template 4 | 5 | [![Pytorch](https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) 6 | [![Pytorch Lightning](https://img.shields.io/badge/-Lightning-ffffff?logo=data:image/svg%2bxml;base64,PHN2ZyB3aWR0aD0iNDEiIGhlaWdodD0iNDgiIHZpZXdCb3g9IjAgMCA0MSA0OCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTIwLjQ5OTIgMEwwIDEyVjM2TDIwLjUgNDhMNDEgMzZWMTJMMjAuNDk5MiAwWk0xNi45NTAxIDM2LjAwMTZMMTkuMTA4OSAyNi42ODU2TDE0LjI1NDggMjEuODkyTDI0LjA3OTEgMTEuOTk5MkwyMS45MTYzIDIxLjMyOTZMMjYuNzQ0NCAyNi4wOTc2TDE2Ljk1MDEgMzYuMDAxNloiIGZpbGw9IiM3OTJFRTUiLz4KPC9zdmc+Cg==)](https://lightning.ai/docs/pytorch/stable/) 7 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 8 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) 9 | 10 | A template for simple deep learning projects using Lightning 11 | 12 | English | [中文](./README_ZH.md) 13 | 14 |
15 | 16 | ## Introduction 17 | 18 | [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) is to deep learning project development as [MVC](https://en.wikipedia.org/wiki/Model%E2%80%93view%E2%80%93controller) frameworks (such as [Spring](https://spring.io/), [Django](https://www.djangoproject.com/), etc.) are to website development. While it is possible to implement everything from scratch and achieve maximum flexibility (especially since [PyTorch](https://pytorch.org/) and its ecosystem are already quite straightforward), using a framework can help you quickly implement prototypes with guidance from ["best practices"](#best-practice) (personal opinion) to save a lot of boilerplate code through re-usability, and focus on scientific innovation rather than engineering challenges. This template is built using the full Lightning suite, follows the principle of [Occam's razor](https://en.wikipedia.org/wiki/Occam%27s_razor), and is friendly to researchers. It also includes a simple handwritten digit recognition task using the MNIST dataset. The repository also contains some [Tips](#tips), for reference. 19 | 20 | ## "Best Practice" 21 | 22 | ### Using [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) as a deep learning framework: 23 | 24 | Most of the deep learning code can be divided into the following three parts([Reference [Chinese]](https://zhuanlan.zhihu.com/p/120331610)): 25 | 26 | 1. Research code: This part pertains to the model and generally deals with customizations of the model's structure and training. In `Lightning`, this code is abstracted as the `pl.LightningModule` class. While dataset definition can also be included in this part, it is not recommended as it is not relevant to the experiment and should be included in `pl.LightningDataModule` instead. 27 | 28 | 2. Engineering code: This part of the code is essential for its high repeatability, such as setting early stopping, 16-bit precision, and GPU distributed training. In `Lightning`, this code is abstracted as the `pl.Trainer` class. 29 | 30 | 3. Non-essential code: This code is helpful in conducting experiments but is not directly related to the experiment itself, and can even be omitted. For example, gradient checking and outputting logs to `TensorBoard`. In Lightning, this code is abstracted as the `Callbacks` class, which is registered to `pl.Trainer`. 31 | 32 | The advantages of using `Lightning`: 33 | 34 | 1. Custom training processes and learning rate adjustment strategies can be implemented through various [hook functions](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks) in `pl.LightningModule`. 35 | 36 | 2. The model and data no longer need to be explicitly designated for devices (`tensor.to`, `tensor.cuda`, etc.). `pl.Trainer` handles this automatically, thereby supporting various [acceleration devices such as CPU, GPU, and TPU](https://lightning.ai/docs/pytorch/latest/extensions/accelerator.html). 37 | 38 | 3. `pl.Trainer` implements various [training strategies](https://lightning.ai/docs/pytorch/latest/extensions/strategy.html), such as automatic mixed precision training, multi-GPU training, and distributed training. 39 | 40 | 4. `pl.Trainer` implements multiple [callbacks](https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html) such as automatic model saving, automatic config saving, and automatic visualization result saving. 
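As a rough illustration of the division above (a minimal sketch with invented names, not code from this template): the `LightningModule` below contains only research code, while the engineering concerns are delegated to a `Trainer`.

```python
import pytorch_lightning as pl
import torch
from torch import nn


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def training_step(self, batch, batch_idx):
        # No .to(device)/.cuda() calls here -- the Trainer moves data and model.
        x, y = batch
        loss = nn.functional.cross_entropy(self.net(x), y)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Engineering code becomes Trainer configuration instead of hand-written loops:
# trainer = pl.Trainer(accelerator="auto", devices=1, precision="16-mixed", max_epochs=3)
# trainer.fit(TinyModule(), train_dataloaders=...)
```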
### Using [PyTorch Lightning CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli) as a command-line tool:

1. With `LightningCLI` as the program entry point, model, data, and training parameters can be set through configuration files or command-line arguments, making it quick to switch between experiments.

2. `LightningCLI` inspects the `__init__` signatures of your module and data module (together with `pl.LightningModule.save_hyperparameters()`) and automatically generates a command-line argument table, eliminating the need for tools such as [`argparse`](https://docs.python.org/3/library/argparse.html) or [`hydra`](https://hydra.cc/).

### Using [Torchmetrics](https://torchmetrics.readthedocs.io/en/stable/) as a metric computation tool:

1. `Torchmetrics` provides many ready-made metrics, such as `Accuracy`, `Precision`, and `Recall`.

2. It integrates with `Lightning` and is compatible with parallel training strategies: metric states are automatically synchronized across processes before computation.

### [Optional] Using [WandB](https://wandb.ai/) to track experiments

### Project Architecture

```mermaid
graph TD;
    A[LightningCLI]---B[LightningModule]
    A---C[LightningDataModule]
    B---D[models]
    B---E[metrics]
    B---F[...]
    C---G[dataloaders]
    G---H[datasets]
```

### File Structure

```text
├── configs                  # Configuration files
│   ├── data                 # Dataset configuration
│   │   └── mnist.yaml       # Example configuration for the MNIST dataset
│   ├── model                # Model configuration
│   │   └── simplenet.yaml   # Example configuration for the SimpleNet model
│   └── default.yaml         # Default configuration
├── data                     # Dataset directory
├── logs                     # Log directory
├── notebooks                # Jupyter Notebook directory
├── scripts                  # Script directory
│   └── clear_wandb_cache.py # Example script to clear the wandb cache
├── src                      # Source code directory
│   ├── callbacks            # Callbacks directory
│   │   └── __init__.py
│   ├── data_modules         # Data module directory
│   │   ├── __init__.py
│   │   └── mnist.py         # Example data module for the MNIST dataset
│   ├── metrics              # Metrics directory
│   │   └── __init__.py
│   ├── models               # Model directory
│   │   ├── __init__.py
│   │   └── simplenet.py     # Example SimpleNet model
│   ├── modules              # Module directory
│   │   ├── __init__.py
│   │   └── mnist_module.py  # Example MNIST module
│   ├── utils                # Utility directory
│   │   ├── __init__.py
│   │   └── cli.py           # CLI tool
│   ├── __init__.py
│   └── main.py              # Main program entry point
├── .env.example             # Example environment variable file
├── .gitignore               # Ignore rules for git
├── .project-root            # Project root indicator file for pyrootutils
├── .python-version          # Python version pin
├── .vscode                  # VS Code settings and recommended extensions
├── LICENSE                  # Open source license
├── pyproject.toml           # Project configuration (dependencies, Ruff, etc.)
├── README.md                # Project documentation
├── README_PROJECT.md        # Project documentation template
└── README_ZH.md             # Project documentation in Chinese
```

## Usage

### Installation

```bash
# Clone the project
git clone https://github.com/DavidZhang73/pytorch-lightning-template
cd pytorch-lightning-template

# Install uv, https://docs.astral.sh/uv/getting-started/installation/
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install dependencies
uv sync
```
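Before configuring anything, it helps to see what the entry point does. A stripped-down `src/main.py` built on a plain `LightningCLI` might look like the sketch below; this is an assumption-laden illustration, since the template's real entry point uses `CustomLightningCLI` from `src/utils/cli.py` and differs in detail.

```python
import pyrootutils

# Locate the project root via the .project-root indicator file, load .env,
# and add the root to sys.path (what the template relies on pyrootutils for).
pyrootutils.setup_root(__file__, indicator=".project-root", dotenv=True, pythonpath=True)

from lightning.pytorch.cli import LightningCLI


def main():
    # Model and data module classes are resolved from the YAML configs or
    # command-line class paths, so nothing is hard-coded here.
    LightningCLI()


if __name__ == "__main__":
    main()
```

With such an entry point, `python src/main.py fit --help` lists the auto-generated argument table described above.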
### Configuration

1. Define the dataset by subclassing `pl.LightningDataModule` in `src/data_modules`.
2. Define the dataset configuration file in `configs/data`; it supplies the parameters for your custom `pl.LightningDataModule`.
3. Define the model by subclassing `nn.Module` in `src/models`.
4. Define metrics by subclassing `torchmetrics.Metric` in `src/metrics`.
5. Define the training module by subclassing `pl.LightningModule` in `src/modules`.
6. Define the configuration file for the training module in `configs/model`; it supplies the parameters for your custom `pl.LightningModule`.
7. Configure `pl.Trainer`, logging, and other parameters in `configs/default.yaml`.

### Run

**Fit**

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

**Validate**

```bash
python src/main.py validate -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

**Test**

```bash
python src/main.py test -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

**Inference**

```bash
python src/main.py predict -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

**Debug**

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.fast_dev_run true
```

**Resume**

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --ckpt_path <path/to/checkpoint.ckpt> --trainer.logger.id exp1_id
```

## Tips

### `print_config`

`jsonargparse`'s `print_config` prints the fully parsed configuration, which you can redirect to a file to generate a default `yaml`. Note that the `yaml` files for `data` and `model` must be configured first.

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --print_config
```

[Prepare a config file for the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html#prepare-a-config-file-for-the-cli)

### Customized `LightningCLI`

This template implements a custom CLI (`CustomLightningCLI`) that adds the following functionality (a sketch of the mechanism follows this list):

- On startup, automatically saves the configuration file to the corresponding log directory (`WandbLogger` only).
- On startup, saves the optimizer and scheduler configurations to the loggers.
- On startup, automatically loads the default configuration file.
- After testing completes, prints the `checkpoint_path` that was used.
- Adds several command-line flags:
  - `--ignore_warnings` (default: `False`): ignore all warnings.
  - `--test_after_fit` (default: `False`): automatically run testing after each training run.
  - `--git_commit_before_fit` (default: `False`): `git commit` before each training run, with commit message `{logger.name}_{logger.version}` (`WandbLogger` only).
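As an illustration of how such flags can be wired in (a sketch only, not the template's actual `CustomLightningCLI` implementation), a `LightningCLI` subclass can register extra top-level arguments in `add_arguments_to_parser` and act on them in `before_instantiate_classes`:

```python
import warnings

from lightning.pytorch.cli import LightningCLI


class SketchCLI(LightningCLI):
    """Hypothetical subclass showing the extension points used for custom flags."""

    def add_arguments_to_parser(self, parser):
        # Registers a flag alongside the auto-generated --model/--data/--trainer ones.
        parser.add_argument("--ignore_warnings", type=bool, default=False)

    def before_instantiate_classes(self):
        # Runs after parsing, before the model/data/trainer are instantiated.
        # With subcommands, parsed values live under e.g. self.config["fit"].
        config = self.config.get(str(self.subcommand), self.config)
        if config.get("ignore_warnings"):
            warnings.filterwarnings("ignore")
```

The same parser object also exposes `parser.link_arguments(...)`, which is the usual way for a CLI subclass to tie, say, a data module attribute to a model argument.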
[Configure hyperparameters from the CLI (expert)](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html)

### Limit the number of `numpy` threads

When running on a server, especially one with many CPU cores (>= 24), `numpy` (via its BLAS/OpenMP backends) may spawn too many threads, which can cause experiments to hang inexplicably. You can limit the number of threads by setting environment variables (in the `.env` file).

```text
OMP_NUM_THREADS=8
MKL_NUM_THREADS=8
GOTO_NUM_THREADS=8
NUMEXPR_NUM_THREADS=8
OPENBLAS_NUM_THREADS=8
MKL_DOMAIN_NUM_THREADS=8
VECLIB_MAXIMUM_THREADS=8
```

> The `.env` file is automatically loaded into the environment by [`pyrootutils`](https://github.com/ashleve/pyrootutils) via [`python-dotenv`](https://github.com/theskumar/python-dotenv).

[Stack Overflow: Limit number of threads in numpy](https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy)

### Clear the `wandb` cache

When you delete an experiment in the `wandb` web UI, its cache still exists in the local `wandb` directory. You can use the `scripts/clear_wandb_cache.py` script to clear it.

[Wandb Python Documentation](https://docs.wandb.ai/ref/python/)

## References

Inspired by:

- [deep-learning-project-template](https://github.com/Lightning-AI/deep-learning-project-template)
- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)