├── .gitignore
├── README.md
├── models
│   └── networks.py
├── opt.py
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
ckpts/
logs/
MNIST/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-lightning-tutorial
PyTorch Lightning tutorial using MNIST

[YouTube stream](https://www.youtube.com/watch?v=O7dNXpgdWbo&ab_channel=AI%E8%91%B5)
(Maybe there will be another... still planning!)

[PyTorch Lightning introduction](https://github.com/PyTorchLightning/pytorch-lightning)

[scheduler introduction (Japanese)](https://katsura-jp.hatenablog.com/entry/2019/01/30/183501)

# Installation

Python>=3.7 is required; creating the environment with Anaconda is recommended. Install the libraries with `pip install -r requirements.txt`.
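
For example, a minimal Anaconda setup might look like this (the environment name `pl-tutorial` and the Python version are arbitrary choices):
```bash
conda create -n pl-tutorial python=3.8 -y
conda activate pl-tutorial
pip install -r requirements.txt
```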

# Train MNIST

Run (example):
```bash
python train.py --root_dir "./"
```

It will download the dataset to `root_dir` and start training. You can monitor the training process by launching TensorBoard in another terminal:
```bash
tensorboard --logdir logs
```

Then open `localhost:6006` in your browser.
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
from torch import nn

from einops import rearrange


class LinearModel(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 10)
        )

    def forward(self, x):
        """
        x: (B, 1, 28, 28) batch of images
        """
        x = rearrange(x, 'b 1 x y -> b (x y)', x=28, y=28)
        return self.net(x)
--------------------------------------------------------------------------------
/opt.py:
--------------------------------------------------------------------------------
import argparse

def get_opts():
    parser = argparse.ArgumentParser()

    parser.add_argument('--root_dir', type=str, required=True,
                        help='root directory of dataset')
    parser.add_argument('--hidden_dim', type=int, default=128,
                        help='number of hidden dimensions')

    parser.add_argument('--val_size', type=int, default=5000,
                        help='size of validation set')

    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--num_epochs', type=int, default=10,
                        help='number of epochs')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers for data loader')

    parser.add_argument('--exp_name', type=str, default='exp',
                        help='experiment name')

    return parser.parse_args()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.11.0
torchvision==0.12.0
pytorch_lightning==1.6.0
einops==0.4.1
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import functional as F

from opt import get_opts

# datasets
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as T

# models
from models.networks import LinearModel

# optimizer
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

seed_everything(1234, workers=True)


def get_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


class MNISTSystem(LightningModule):
    def __init__(self, hparams):
        super().__init__()
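        # save_hyperparameters stores the argparse Namespace under self.hparams
        # and writes it into every checkpoint, so the module can later be
        # restored via MNISTSystem.load_from_checkpoint without re-passing it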
        self.save_hyperparameters(hparams)
        self.net = LinearModel(self.hparams.hidden_dim)

    def forward(self, x):
        return self.net(x)

    def prepare_data(self):
        """
        download data once
        """
        MNIST(self.hparams.root_dir, train=True, download=True)
        MNIST(self.hparams.root_dir, train=False, download=True)

    def setup(self, stage=None):
        """
        setup dataset for each machine
        """
        dataset = MNIST(self.hparams.root_dir,
                        train=True,
                        download=False,
                        transform=T.ToTensor())
        train_length = len(dataset)  # 60000
        self.train_dataset, self.val_dataset = \
            random_split(dataset,
                         [train_length - self.hparams.val_size,
                          self.hparams.val_size])

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          num_workers=self.hparams.num_workers,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          shuffle=False,
                          num_workers=self.hparams.num_workers,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def configure_optimizers(self):
        self.optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)

        scheduler = CosineAnnealingLR(self.optimizer,
                                      T_max=self.hparams.num_epochs,
                                      eta_min=self.hparams.lr / 1e2)

        return [self.optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits_predicted = self(images)

        loss = F.cross_entropy(logits_predicted, labels)

        self.log('lr', get_learning_rate(self.optimizer))
        self.log('train/loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits_predicted = self(images)

        loss = F.cross_entropy(logits_predicted, labels)
        acc = torch.eq(torch.argmax(logits_predicted, -1), labels).float().mean()

        log = {'val_loss': loss,
               'val_acc': acc}

        return log

    def validation_epoch_end(self, outputs):
        mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        mean_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        self.log('val/loss', mean_loss, prog_bar=True)
        self.log('val/acc', mean_acc, prog_bar=True)


if __name__ == '__main__':
    hparams = get_opts()
    mnistsystem = MNISTSystem(hparams)

    ckpt_cb = ModelCheckpoint(dirpath=f'ckpts/{hparams.exp_name}',
                              filename='{epoch:d}',
                              monitor='val/acc',
                              mode='max',
                              save_top_k=5)
    pbar = TQDMProgressBar(refresh_rate=1)
    callbacks = [ckpt_cb, pbar]

    logger = TensorBoardLogger(save_dir="logs",
                               name=hparams.exp_name,
                               default_hp_metric=False)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=True,
                      accelerator='auto',
                      devices=1,
                      num_sanity_val_steps=1,
                      benchmark=True)

    trainer.fit(mnistsystem)
--------------------------------------------------------------------------------
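
After training, a saved checkpoint can be restored for inference. This is a minimal sketch, assuming a checkpoint was written by the `ModelCheckpoint` callback above (the exact filename `epoch=9.ckpt` is hypothetical -- check `ckpts/<exp_name>/` for the real one):

```python
import torch

from train import MNISTSystem

# hparams were stored via save_hyperparameters, so they need not be passed again
system = MNISTSystem.load_from_checkpoint('ckpts/exp/epoch=9.ckpt')
system.eval()

with torch.no_grad():
    image = torch.rand(1, 1, 28, 28)         # one fake MNIST-shaped image
    digit = system(image).argmax(-1).item()  # predicted class index
print(digit)
```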