├── .gitignore
├── CITATION.cff
├── INSTALL.md
├── LICENCE
├── README.md
├── assets
│   └── Pytorch Boiler.png
├── example_projects
│   ├── __init__.py
│   ├── image_autoencoder
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── model.py
│   │   └── train.py
│   └── image_classifier
│       ├── __init__.py
│       ├── dataloader.py
│       ├── model.py
│       └── train.py
└── pytorch_boiler
    ├── __init__.py
    ├── boiler.py
    ├── tracker.py
    └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
data/
model/

--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

cff-version: 1.2.0
title: Pytorch Boiler
message: github.com/nmakes/pytorch_boiler
type: software
authors:
  - given-names: Naveen
    family-names: Venkat
    email: nav.naveenvenkat@gmail.com

--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
# Install

## 1. Install / update conda

For a fresh conda installation:

1. [Download the Miniconda installer](https://docs.conda.io/en/latest/miniconda.html#linux-installers)
2. Run the installer: `bash Miniconda3-latest-Linux-x86_64.sh`

If you have an existing conda installation, update it using:

```
conda update -n base -c defaults conda
```

## 2. Create an environment

Create a conda environment and install the dependencies:

```
conda create -n boiler python=3.9
conda activate boiler
conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
python3 -m pip install matplotlib
```

## 3. Download boiler

```
git clone git@github.com:nmakes/pytorch_boiler.git
```

## 4. Run sample model training

```
cd pytorch_boiler
PYTHONPATH=$PYTHONPATH:./ python3 -m \
example_projects.image_classifier.train
```

--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Naveen Venkat

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![Pytorch Boiler](assets/Pytorch%20Boiler.png)

[![Python 3 version](https://img.shields.io/badge/python-%3E%3D3.6-blue)](https://www.python.org/downloads/release/python-360/)
[![Pytorch version](https://img.shields.io/badge/pytorch-%3E%3D%201.4.0-informational)](https://pytorch.org/get-started/previous-versions/)
[![Code Size](https://img.shields.io/github/languages/code-size/nmakes/pytorch_boiler)](https://github.com/nmakes/pytorch_boiler/)
[![LICENCE](https://img.shields.io/badge/licence-MIT-blueviolet)](LICENCE)

# Introduction

Pytorch Boiler is minimalistic boilerplate code for training PyTorch models.

## Quick Start

* Clone this repository

```bash
git clone https://github.com/nmakes/pytorch_boiler
cd pytorch_boiler
```

* Run sample experiments

40-line [MNIST/CIFAR classification](example_projects/image_classifier/train.py):

```bash
PYTHONPATH=$PYTHONPATH:./ python3 -m example_projects.image_classifier.train
```

50-line [MNIST/CIFAR autoencoder](example_projects/image_autoencoder/train.py):

```bash
PYTHONPATH=$PYTHONPATH:./ python3 -m example_projects.image_autoencoder.train
```

## Installation

Basic Requirements:

* `numpy`
* `pytorch`
* `torchvision`

Other Requirements:

* Mixed-precision training uses PyTorch's native AMP (`torch.autocast` and `torch.cuda.amp.GradScaler`); no separate `nvidia-apex` install is required

## Supported Functionalities

* Customizable train / inference engine with forward and infer modes
* Tracking multiple training / validation losses and metrics (see the appendix at the end of this README)
* Loading / saving the model, optimizer and trackers based on validation loss
* Training MNIST / CIFAR in 40 lines (see [train.py](example_projects/image_classifier/train.py))
* Mixed-precision training via PyTorch's native AMP

## TODO

* Support for multiple loss optimization using multiple optimizers
* Support for tensorboard plots
* Support for torchscript
* Documentation: using metrics and losses beyond the example scripts

# Cite

If you found Pytorch Boiler useful in your project, please cite the following:

```
@software{Venkat_Pytorch_Boiler,
  author = {Venkat, Naveen},
  title = {{Pytorch Boiler}},
  url = {https://github.com/nmakes/pytorch_boiler},
  year = {2022}
}
```

Thanks!
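
# Appendix: Multiple Losses and Metrics

The `loss` and `performance` overloads can return either a single tensor or a dictionary of named values. When the dictionary contains more than one key, Boiler requires a `'summary'` entry: for `loss` it is the value that gets backpropagated, and for `performance` it is the headline metric tracked as `training_perf` / `validation_perf`. Every other key is tracked under its own name (e.g. `training_xe_loss`). A minimal sketch of this convention follows; the `xe_loss` / `l2_reg` names are illustrative, not part of the API:

```python
import torch

from pytorch_boiler import Boiler, overload


class MultiLossTrainer(Boiler):

    @overload
    def loss(self, model_output, data):
        images, labels = data
        xe_loss = torch.nn.functional.cross_entropy(model_output, labels.cuda())
        l2_reg = sum((p ** 2).sum() for p in self.model.parameters())
        # 'xe_loss' and 'l2_reg' are tracked as training_xe_loss / training_l2_reg;
        # 'summary' is the value Boiler backpropagates.
        return {'xe_loss': xe_loss, 'l2_reg': l2_reg, 'summary': xe_loss + 1e-5 * l2_reg}

    @overload
    def performance(self, model_output, data):
        images, labels = data
        acc = (model_output.argmax(dim=-1) == labels.cuda()).float().mean()
        # 'summary' is the headline metric tracked as training_perf / validation_perf.
        return {'acc': acc, 'summary': acc}
```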
--------------------------------------------------------------------------------
/assets/Pytorch Boiler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/assets/Pytorch Boiler.png

--------------------------------------------------------------------------------
/example_projects/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/example_projects/__init__.py

--------------------------------------------------------------------------------
/example_projects/image_autoencoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/example_projects/image_autoencoder/__init__.py

--------------------------------------------------------------------------------
/example_projects/image_autoencoder/dataloader.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
import torchvision.datasets as tvd
import torchvision.transforms as tvt


def cifar10_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
    if train:
        transforms = tvt.Compose([
            tvt.RandomHorizontalFlip(),
            tvt.RandomResizedCrop(size=32, scale=(0.08, 1)),
            tvt.ToTensor(),
            tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        transforms = tvt.Compose([
            tvt.ToTensor(),
            tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    loader = DataLoader(
        dataset=tvd.CIFAR10(root=root, train=train, transform=transforms, download=True),
        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
    )

    return loader


def mnist_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
    # Train and eval use the same transforms here: resize MNIST to 32x32 to match
    # the autoencoder's three pooling stages, then normalize to [-1, 1].
    if train:
        transforms = tvt.Compose([
            tvt.Resize((32, 32)),
            tvt.ToTensor(),
            tvt.Normalize(0.5, 0.5)
        ])
    else:
        transforms = tvt.Compose([
            tvt.Resize((32, 32)),
            tvt.ToTensor(),
            tvt.Normalize(0.5, 0.5)
        ])

    loader = DataLoader(
        dataset=tvd.MNIST(root=root, train=train, transform=transforms, download=True),
        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
    )

    return loader

--------------------------------------------------------------------------------
/example_projects/image_autoencoder/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class GAPFlatten(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.mean(dim=-1)  # b, c, h
        x = x.mean(dim=-1)  # b, c
        return x  # b, c


class GAPUnflatten(nn.Module):

    def __init__(self, output_shape=(4, 4)):
        super().__init__()
        self.h, self.w = output_shape
        self.linear = nn.Linear(1, self.h * self.w)

    def forward(self, x):
        b, c = x.shape[:2]
        x = x.reshape(b, c, 1)  # (b, c, 1)
        x = self.linear(x)  # (b, c, h*w)
        x = x.reshape(b, c, self.h, self.w)  # (b, c, h, w)
        return x


class TinyResNetAE(nn.Module):

    def __init__(self, in_channels, hidden_channels, expansion_factor=2, latent_image_size=(4, 4)):
        super().__init__()
        exf = expansion_factor

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),  # 32
            nn.MaxPool2d(kernel_size=exf, stride=exf),  # 16

            nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels * exf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels * exf),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=exf, stride=exf),  # 8

            nn.Conv2d(in_channels=hidden_channels * exf, out_channels=hidden_channels * (exf**2), kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels * (exf**2)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=exf, stride=exf),  # 4

            nn.Conv2d(in_channels=hidden_channels * (exf**2), out_channels=hidden_channels * (exf**2), kernel_size=3, stride=1, padding=1),  # 4
            GAPFlatten(),
        )

        self.decoder = nn.Sequential(
            GAPUnflatten(output_shape=latent_image_size),

            nn.Conv2d(in_channels=hidden_channels * (exf**2), out_channels=hidden_channels * exf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels * exf),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),  # 8

            nn.Conv2d(in_channels=hidden_channels * exf, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),  # 16

            nn.Conv2d(in_channels=hidden_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),  # 32

            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Images will be normalized between [-1, 1]
        )

    def forward(self, x):
        encoding = self.encoder(x)
        decoded_image = self.decoder(encoding)
        return encoding, decoded_image

--------------------------------------------------------------------------------
/example_projects/image_autoencoder/train.py:
--------------------------------------------------------------------------------
"""Example usage of Pytorch Boiler"""
import torch

from pytorch_boiler import Boiler, overload
from .model import TinyResNetAE
from .dataloader import mnist_dataloader, cifar10_dataloader


class Trainer(Boiler):

    @overload
    def pre_process(self, data):
        images, labels = data
        return images.cuda()

    @overload
    def loss(self, model_output, data):  # Overload the loss function
        images, labels = data
        latent_encoding, predicted_images = model_output
        l2_loss = ((images.cuda() - predicted_images) ** 2).mean()
        return l2_loss  # Can return a tensor, or a dictionary like {'l2_loss': l2_loss} with multiple losses. See README.
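
    # A dictionary-returning variant of loss() (sketch): with more than one
    # key, Boiler expects a 'summary' entry naming the loss to backpropagate:
    #   return {'l2_loss': l2_loss, 'summary': l2_loss}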

    @overload
    def performance(self, model_output, data):
        images, labels = data
        latent_encoding, predicted_images = model_output
        diff = (images.cuda() - predicted_images) ** 2
        reconstruction_error = diff.mean()
        max_error = diff.mean(dim=(1, 2, 3)).amax(dim=0)
        min_error = diff.mean(dim=(1, 2, 3)).amin(dim=0)
        return_values = {
            'avg_reconstruction_error': reconstruction_error,
            'avg_max_error': max_error,
            'avg_min_error': min_error,
            'summary': reconstruction_error
        }
        return return_values  # Can return a tensor, or a dictionary like {'acc': acc} with multiple metrics. See README.


if __name__ == '__main__':
    dataloader = mnist_dataloader; dataset = 'mnist'; in_channels = 1; batch_size = 256; exp_tag = 'autoencoder_mnist'
    # dataloader = cifar10_dataloader; dataset = 'cifar10'; in_channels = 3; batch_size = 256; exp_tag = 'autoencoder_cifar10'
    model = TinyResNetAE(in_channels=in_channels, hidden_channels=4, expansion_factor=2, latent_image_size=(4, 4)).cuda()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=1)
    train_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=True, shuffle=True, drop_last=True)
    val_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=False, shuffle=False, drop_last=False)

    trainer = Trainer(model=model, optimizer=optimizer, scheduler=scheduler,
                      train_dataloader=train_dataloader, val_dataloader=val_dataloader,
                      epochs=10, save_path=f'./model/{exp_tag}/state_dict.pt', load_path=None, mixed_precision=True)
    trainer.fit()

--------------------------------------------------------------------------------
/example_projects/image_classifier/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/example_projects/image_classifier/__init__.py

--------------------------------------------------------------------------------
/example_projects/image_classifier/dataloader.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
import torchvision.datasets as tvd
import torchvision.transforms as tvt


def cifar10_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
    if train:
        transforms = tvt.Compose([
            tvt.RandomHorizontalFlip(),
            tvt.RandomResizedCrop(size=32, scale=(0.08, 1)),
            tvt.ToTensor(),
            tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        transforms = tvt.Compose([
            tvt.ToTensor(),
            tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    loader = DataLoader(
        dataset=tvd.CIFAR10(root=root, train=train, transform=transforms, download=True),
        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
    )

    return loader


def mnist_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
    if train:
        transforms = tvt.Compose([
            tvt.ToTensor(),
            tvt.Normalize(0.5, 0.5)
        ])
    else:
        transforms = tvt.Compose([
            tvt.ToTensor(),
            tvt.Normalize(0.5, 0.5)
        ])

    loader = DataLoader(
        dataset=tvd.MNIST(root=root, train=train, transform=transforms, download=True),
        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
    )

    return loader

--------------------------------------------------------------------------------
/example_projects/image_classifier/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class TinyResNet(nn.Module):

    def __init__(self, in_channels, hidden_channels, output_channels, num_layers, expansion_factor=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1)

        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(hidden_channels * (2**i), hidden_channels * (2**i), kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(hidden_channels * (2**i)),
                nn.ReLU(),
                nn.Conv2d(hidden_channels * (2**i), hidden_channels * (2**(i+1)), kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(hidden_channels * (2**(i+1))),
                nn.ReLU(),
                nn.Conv2d(hidden_channels * (2**(i+1)), hidden_channels * (2**(i+1)), kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(hidden_channels * (2**(i+1))),
                nn.ReLU(),
            )
            for i in range(num_layers)
        ])

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.skip_connections = nn.ModuleList([
            nn.Conv2d(hidden_channels * (2**i), hidden_channels * (2**(i+1)), kernel_size=3, stride=1, padding=1)
            for i in range(num_layers)
        ])

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Sequential(
            nn.Linear(hidden_channels * (2**(num_layers)), hidden_channels * (2**(num_layers))),
            nn.ReLU(),
            nn.Linear(hidden_channels * (2**(num_layers)), output_channels)
        )

    def forward(self, x):
        x = self.conv1(x)  # (B, C, H, W)

        for conv_block, skip_conn in zip(self.conv_blocks, self.skip_connections):
            block_out = conv_block(x)
            skip_out = skip_conn(x)
            x = block_out + skip_out
            x = self.max_pool(x)

        x = self.global_avg_pool(x)  # (B, C, 1, 1)
        x = x.squeeze(-1).squeeze(-1)

        x = self.fc(x)
        return x

--------------------------------------------------------------------------------
/example_projects/image_classifier/train.py:
--------------------------------------------------------------------------------
"""Example usage of Pytorch Boiler"""
import torch

from pytorch_boiler import Boiler, overload
from .model import TinyResNet
from .dataloader import mnist_dataloader, cifar10_dataloader


class Trainer(Boiler):

    @overload
    def pre_process(self, data):
        images, labels = data
        return images.cuda()

    @overload
    def loss(self, model_output, data):  # Overload the loss function
        images, labels = data
        xe_loss = torch.nn.functional.cross_entropy(model_output, labels.cuda(), reduction='mean')
        return xe_loss  # Can return a tensor, or a dictionary like {'xe_loss': xe_loss} with multiple losses. See README.

    @overload
    def performance(self, model_output, data):
        images, labels = data
        preds = model_output.argmax(dim=-1)
        acc = (preds == labels.cuda()).float().mean()
        return acc.cpu().detach().numpy()  # Can return a tensor, or a dictionary like {'acc': acc} with multiple metrics. See README.
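
    # A dictionary-returning variant of performance() (sketch): with more than
    # one key, Boiler expects a 'summary' entry naming the headline metric to
    # track as training_perf / validation_perf:
    #   return {'acc': acc, 'summary': acc}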


if __name__ == '__main__':
    # dataloader = mnist_dataloader; dataset = 'mnist'; in_channels = 1; batch_size = 256; exp_tag = 'classifier_mnist'
    dataloader = cifar10_dataloader; dataset = 'cifar10'; in_channels = 3; batch_size = 256; exp_tag = 'classifier_cifar10'
    model = TinyResNet(in_channels=in_channels, hidden_channels=4, output_channels=10, num_layers=3, expansion_factor=2).cuda()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=1)
    train_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=True, shuffle=True, drop_last=True)
    val_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=False, shuffle=False, drop_last=False)

    trainer = Trainer(model=model, optimizer=optimizer, scheduler=scheduler,
                      train_dataloader=train_dataloader, val_dataloader=val_dataloader,
                      epochs=10, save_path=f'./model/{exp_tag}/state_dict.pt', load_path=None, mixed_precision=True)
    trainer.fit()

--------------------------------------------------------------------------------
/pytorch_boiler/__init__.py:
--------------------------------------------------------------------------------
from .boiler import (
    Boiler
)

from .utils import (
    overload
)

from .tracker import (
    Tracker
)

__all__ = [
    # Boiler imports
    'Boiler',

    # Util imports
    'overload',

    # Tracker imports
    'Tracker'
]

--------------------------------------------------------------------------------
/pytorch_boiler/boiler.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm import tqdm
from typing import Any

from .utils import init_overload_state, is_method_overloaded, prettify_dict
from .tracker import Tracker


class Boiler(nn.Module):

    def __init__(self, model, optimizer, scheduler, train_dataloader, val_dataloader, epochs, patience=None, save_path=None, load_path=None, mixed_precision=False):
        super(Boiler, self).__init__()
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.epochs = epochs
        self.patience = patience

        self.save_path = save_path
        self.load_path = load_path

        self.mixed_precision = mixed_precision

        # Initialize tracker
        self.tracker = Tracker()
        self.patience_counter = 0

        # Initialize mixed precision
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

        # Load model state
        self.start_epoch = 0
        self.best_validation_loss = float('inf')
        if self.load_path is not None:
            print('\nBoiler | Loading from {}'.format(self.load_path))
            loaded_object = torch.load(self.load_path)
            self.model.load_state_dict(loaded_object['model_state_dict'])
            self.optimizer.load_state_dict(loaded_object['optimizer_state_dict'])
            self.scaler.load_state_dict(loaded_object['scaler_state_dict'])
            self.scheduler.load_state_dict(loaded_object['scheduler_state_dict'])
            self.tracker.load_state_dict(loaded_object['tracker'])
            self.start_epoch = loaded_object['epoch'] + 1
            self.best_validation_loss = loaded_object['best_validation_loss']

    @init_overload_state
    def pre_process(self, x):
        """Computes the pre-processing steps. Must be overloaded for usage.

        Args:
            x (Any): The input to the pre-processor.

        Raises:
            NotImplementedError: If this function is not implemented and is called.
        """
        raise NotImplementedError()

    @init_overload_state
    def post_process(self, x):
        """Computes the post-processing steps. Must be overloaded for usage.

        Args:
            x (Any): The input to the post-processor (output from the model forward pass).

        Raises:
            NotImplementedError: If this function is not implemented and is called.
        """
        raise NotImplementedError()

    @init_overload_state
    def forward(self, x):
        """Computes the forward pass, with additional pre- and post-processing.

        Args:
            x (torch.Tensor|Any): The tensor input to the model forward pass, or input to the pre-processor.

        Returns:
            torch.Tensor|Any: The tensor output of the model, or the output from the post-processor.
        """
        if is_method_overloaded(self.pre_process):
            x = self.pre_process(x)
        with torch.autocast('cuda', dtype=torch.float16, enabled=(self.mixed_precision and self.training)):
            x = self.model(x)
        if is_method_overloaded(self.post_process):
            x = self.post_process(x)
        return x

    @init_overload_state
    def loss(self, output, data):
        """Computes the loss. Must be overloaded for usage.

        Args:
            output (torch.Tensor|Any): The output from the model forward pass, or output from the post-processor.
            data (torch.Tensor|Any): The data batch loaded from the dataloader.

        Raises:
            NotImplementedError: If this function is not implemented and is called.
        """
        raise NotImplementedError()

    @init_overload_state
    def performance(self, output, data):
        """Computes the performance metric. Must be overloaded for usage.

        Args:
            output (torch.Tensor|Any): The output from the model forward pass, or output from the post-processor.
            data (torch.Tensor|Any): The data batch loaded from the dataloader.

        Returns:
            torch.Tensor|Any: The performance metric computed on the current batch.
        """
        raise NotImplementedError()

    @init_overload_state
    def infer(self, x):
        """Computes the inference in eval mode.

        Args:
            x (torch.Tensor|Any): The tensor input to the model forward pass, or input to the pre-processor.

        Returns:
            torch.Tensor|Any: The tensor output of the model, or the output from the post-processor.
        """
        # If model is in training mode, convert to eval
        training_state_true = self.training
        if training_state_true:
            self.eval()

        with torch.no_grad():
            output = self.forward(x)

        # If model was in training mode, convert back to train from eval
        if training_state_true:
            self.train()

        return output

    @init_overload_state
    def decode_item_type(self, item, require_summary=True):
        """Decodes the type of the item. This function is used to handle multiple metrics.

        Args:
            item (torch.Tensor|np.ndarray|dict): Either a tensor describing a loss/metric, or a dictionary of named tensors.
            require_summary (bool): If True, a dictionary item is required to contain a 'summary' key when more than one key is present.

        Returns:
            dict: A dictionary with named tensors, where "summary" is a special key that specifies the loss to be backpropagated / metric to be tracked.
        """
        assert type(item) in [torch.Tensor, np.ndarray, dict], f"Either a tensor, numpy array, or a dictionary of (tag: value) must be passed. Given {type(item)}."
        if type(item) == torch.Tensor or type(item) == np.ndarray:
            return {'summary': item}
        elif type(item) == dict:
            keys = list(item.keys())
            if len(keys) == 1:
                item['summary'] = item[keys[0]]
            if require_summary:
                assert 'summary' in item.keys(), "If a dictionary is returned from loss / performance, it must contain the key called 'summary' indicating the loss to backpropagate / metric to track."
            return item

    @init_overload_state
    def train_epoch(self):
        """Trains the model for one epoch.

        Args:
            None

        Returns:
            Tracker: The tracker object.
        """
        self.train()  # Set the mode to training
        for data in tqdm(self.train_dataloader):
            # Set optimizer zero grad
            self.optimizer.zero_grad()

            # Compute model output
            model_output = self.forward(data)

            # Compute loss & Update tracker
            loss = self.loss(model_output, data)
            decoded_loss = self.decode_item_type(loss, require_summary=True)
            self.tracker.update(key='training_loss',
                                value=decoded_loss['summary'].cpu().detach().numpy())
            for key in decoded_loss:
                if key != 'summary':
                    self.tracker.update(key='training_{}'.format(key),
                                        value=decoded_loss[key].cpu().detach().numpy())

            # Compute performance and update tracker
            if is_method_overloaded(self.performance):
                perf = self.performance(model_output, data)
                decoded_perf = self.decode_item_type(perf, require_summary=True)
                self.tracker.update(key='training_perf',
                                    value=decoded_perf['summary'].cpu().detach().numpy() \
                                        if type(decoded_perf['summary'])==torch.Tensor \
                                        else decoded_perf['summary'])
                for key in decoded_perf:
                    if key != 'summary':
                        self.tracker.update(key='training_{}'.format(key),
                                            value=decoded_perf[key].cpu().detach().numpy() \
                                                if type(decoded_perf[key])==torch.Tensor \
                                                else decoded_perf[key])

            if self.mixed_precision:
                # Scale the loss
                scaled_loss = self.scaler.scale(decoded_loss['summary'])
                # Backpropagate loss
                scaled_loss.backward()
                # Update model parameters
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Backpropagate loss
                decoded_loss['summary'].backward()
                # Update model parameters
                self.optimizer.step()
        return self.tracker

    @init_overload_state
    def eval_epoch(self):
        """Evaluates the model for one epoch.

        Args:
            None

        Returns:
            Tracker: The tracker object.
        """
        self.eval()  # Set the mode to evaluation
        for data in tqdm(self.val_dataloader):
            # Compute model output
            model_output = self.infer(data)

            # Compute loss & Update tracker
            loss = self.loss(model_output, data)
            decoded_loss = self.decode_item_type(loss, require_summary=True)
            self.tracker.update('validation_loss', decoded_loss['summary'].cpu().detach().numpy())
            for key in decoded_loss:
                if key != 'summary':
                    self.tracker.update('validation_{}'.format(key), decoded_loss[key].cpu().detach().numpy())

            # Compute performance & Update tracker
            if is_method_overloaded(self.performance):
                perf = self.performance(model_output, data)
                decoded_perf = self.decode_item_type(perf, require_summary=True)
                self.tracker.update('validation_perf', decoded_perf['summary'].cpu().detach().numpy() if type(decoded_perf['summary'])==torch.Tensor else decoded_perf['summary'])
                for key in decoded_perf:
                    if key != 'summary':
                        self.tracker.update('validation_{}'.format(key), decoded_perf[key].cpu().detach().numpy() if type(decoded_perf[key])==torch.Tensor else decoded_perf[key])

        return self.tracker

    @init_overload_state
    def fit(self):
        """Fits the model for the given number of epochs.

        Returns:
            Boiler: Self.
        """
        best_validation_loss = self.best_validation_loss
        start_epoch = self.start_epoch

        for e in range(start_epoch, self.epochs):
            print('\nBoiler | Training epoch {}/{}...'.format(e+1, self.epochs))
            self.train_epoch()
            self.eval_epoch()
            self.scheduler.step()
            summary = self.tracker.summarize()
            print(prettify_dict(summary))
            self.tracker.stash()

            if summary['validation_loss'] < best_validation_loss:
                # Track the best validation loss even when save_path is None,
                # so that patience-based early stopping behaves correctly.
                best_validation_loss = summary['validation_loss']
                self.patience_counter = 0
                to_save = True
            else:
                self.patience_counter += 1
                to_save = False

            if self.save_path is not None and to_save:
                print('Saving training state at {}'.format(self.save_path))
                os.makedirs(os.path.abspath(os.path.dirname(self.save_path)), exist_ok=True)
                saved_object = {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scaler_state_dict': self.scaler.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'tracker': self.tracker.state_dict(),
                    'epoch': e,
                    'best_validation_loss': best_validation_loss
                }
                torch.save(saved_object, self.save_path)

            if (self.patience is not None) and (self.patience_counter > self.patience):
                print('Maximum patience reached. Stopping training.')
                break

        return self

--------------------------------------------------------------------------------
/pytorch_boiler/tracker.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import numpy as np
import json


class Tracker:

    def __init__(self):
        self.history = {}
        self.running_history = []

    def update(self, key, value):
        if key not in self.history:
            self.history[key] = []
        self.history[key].append(value)

    def stash(self):
        self.running_history.append(self.history)
        self.history = {}

    def summarize(self):
        summary = {}
        for key in self.history:
            summary[key] = np.mean(self.history[key]).item()
        return summary

    def state_dict(self):
        state = {
            'history': self.history,
            'running_history': self.running_history
        }
        return state

    def load_state_dict(self, state_dict):
        self.history = state_dict['history']
        self.running_history = state_dict['running_history']
        return self

--------------------------------------------------------------------------------
/pytorch_boiler/utils.py:
--------------------------------------------------------------------------------
import json


# -----------------------------
# Utilities to override methods
# -----------------------------
def overload(func):
    func.is_overloaded = True
    return func

def is_method_overloaded(func):
    if not hasattr(func, 'is_overloaded'):
        return False
    else:
        return func.is_overloaded

def init_overload_state(func):
    if not is_method_overloaded(func):
        func.is_overloaded = False
    return func


# ---------------
# Text formatting
# ---------------
def prettify_dict(d):
    return json.dumps(d, indent=2)

--------------------------------------------------------------------------------