├── .gitignore
├── CITATION.cff
├── INSTALL.md
├── LICENCE
├── README.md
├── assets
│   └── Pytorch Boiler.png
├── example_projects
│   ├── __init__.py
│   ├── image_autoencoder
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── model.py
│   │   └── train.py
│   └── image_classifier
│       ├── __init__.py
│       ├── dataloader.py
│       ├── model.py
│       └── train.py
└── pytorch_boiler
    ├── __init__.py
    ├── boiler.py
    ├── tracker.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | data/
3 | model/
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 |
4 | cff-version: 1.2.0
5 | title: Pytorch Boiler
6 | message: github.com/nmakes/pytorch_boiler
7 | type: software
8 | authors:
9 | - given-names: Naveen
10 | family-names: Venkat
11 | email: nav.naveenvenkat@gmail.com
12 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Install
2 |
3 | ## 1. Install / update conda
4 |
5 | For a fresh conda installation:
6 |
7 | 1. [Download the Miniconda installer](https://docs.conda.io/en/latest/miniconda.html#linux-installers)
8 | 2. Run installer: `bash Miniconda3-latest-Linux-x86_64.sh`
9 |
10 | If you have an existing conda installation, update it using:
11 |
12 | ```bash
13 | conda update -n base -c defaults conda
14 | ```
15 |
16 | ## 2. Creating an environment
17 |
18 | Create a conda environment and install the dependencies:
19 |
20 | ```bash
21 | conda create -n boiler python=3.9
22 | conda activate boiler
23 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
24 | python3 -m pip install matplotlib
25 | ```
26 |
27 | ## 3. Download boiler
28 |
29 | ```bash
30 | git clone git@github.com:nmakes/pytorch_boiler.git
31 | ```
32 |
33 | ## 4. Run sample model training
34 | ```bash
35 | cd pytorch_boiler
36 | PYTHONPATH=$PYTHONPATH:./ python3 -m \
37 | example_projects.image_classifier.train
38 | ```
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Naveen Venkat
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | [Python 3.6+](https://www.python.org/downloads/release/python-360/)
6 | [PyTorch](https://pytorch.org/get-started/previous-versions/)
7 |
8 |
9 | [GitHub: nmakes/pytorch_boiler](https://github.com/nmakes/pytorch_boiler/)
10 |
11 | [MIT License](LICENCE)
12 |
13 |
14 | # Introduction
15 | Pytorch Boiler is minimalistic boilerplate code for training PyTorch models.
16 |
17 | ## Quick Start
18 |
19 | * Clone this repository
20 |
21 | ```bash
22 | git clone https://github.com/nmakes/pytorch_boiler
23 | cd pytorch_boiler
24 | ```
25 |
26 | * Run sample experiments
27 |
28 | 40-line [MNIST/CIFAR classification](example_projects/image_classifier/train.py):
29 |
30 | ```bash
31 | PYTHONPATH=$PYTHONPATH:./ python3 -m example_projects.image_classifier.train
32 | ```
33 |
34 | 50-line [MNIST/CIFAR autoencoder](example_projects/image_autoencoder/train.py):
35 |
36 | ```bash
37 | PYTHONPATH=$PYTHONPATH:./ python3 -m example_projects.image_autoencoder.train
38 | ```
39 |
40 | ## Installation
41 |
42 | Basic Requirements:
43 |
44 | * `numpy`
45 | * `pytorch`
46 | * `torchvision`
47 |
48 | Other Requirements:
49 |
50 | * A CUDA-capable GPU (the examples call `.cuda()`; mixed-precision training uses PyTorch's native `torch.cuda.amp`, so no separate Apex install is needed)
51 |
52 |
53 | ## Supported Functionalities
54 |
55 | * Customizable Train / Inference engine with forward and infer modes
56 | * Tracking multiple training / validation losses and metrics
57 | * Loading / saving the model, optimizer, and tracker state based on validation loss
58 | * Training MNIST / CIFAR in 40 lines (see [train.py](example_projects/image_classifier/train.py))
59 | * Supports mixed-precision training via PyTorch's native `torch.cuda.amp`
60 |
61 |
62 | ## TODO
63 |
64 | * Support for optimizing multiple losses with multiple optimizers
65 | * Support for tensorboard plots
66 | * Support for torchscript
67 | * Documentation: using metrics and losses beyond the example script
68 |
69 |
70 | # Cite
71 |
72 | If you found Pytorch Boiler useful in your project, please cite the following:
73 |
74 | ```
75 | @software{Venkat_Pytorch_Boiler,
76 | author = {Venkat, Naveen},
77 | title = {{Pytorch Boiler}},
78 | url = {https://github.com/nmakes/pytorch_boiler},
79 | year = {2022}
80 | }
81 | ```
82 |
83 | Thanks!
--------------------------------------------------------------------------------
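The example trainers note that `loss` and `performance` may return either a tensor or a dictionary, and the README defers the details to a TODO item. Based on `decode_item_type` in `pytorch_boiler/boiler.py`, here is a minimal sketch of the dictionary convention (the trainer class and loss names are illustrative):

```python
import torch

from pytorch_boiler import Boiler, overload


class MultiLossTrainer(Boiler):  # hypothetical trainer for illustration

    @overload
    def loss(self, model_output, data):
        images, labels = data
        xe_loss = torch.nn.functional.cross_entropy(model_output, labels.cuda())
        reg_loss = 1e-4 * sum(p.abs().sum() for p in self.model.parameters())
        # A dict with more than one entry must contain the key 'summary':
        # 'summary' is backpropagated and tracked as training_loss / validation_loss,
        # while every other entry is tracked as training_<key> / validation_<key>.
        return {'xe_loss': xe_loss,
                'reg_loss': reg_loss,
                'summary': xe_loss + reg_loss}
```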
/assets/Pytorch Boiler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/assets/Pytorch Boiler.png
--------------------------------------------------------------------------------
/example_projects/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/example_projects/__init__.py
--------------------------------------------------------------------------------
/example_projects/image_autoencoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/example_projects/image_autoencoder/__init__.py
--------------------------------------------------------------------------------
/example_projects/image_autoencoder/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torchvision.datasets as tvd
3 | import torchvision.transforms as tvt
4 |
5 |
6 | def cifar10_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
7 | if train:
8 | transforms = tvt.Compose([
9 | tvt.RandomHorizontalFlip(),
10 | tvt.RandomResizedCrop(size=32, scale=(0.08, 1)),
11 | tvt.ToTensor(),
12 | tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
13 | ])
14 | else:
15 | transforms = tvt.Compose([
16 | tvt.ToTensor(),
17 | tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
18 | ])
19 |
20 | loader = DataLoader(
21 | dataset=tvd.CIFAR10(root=root, train=train, transform=transforms, download=True),
22 | batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
23 | )
24 |
25 | return loader
26 |
27 |
28 | def mnist_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
29 | if train:
30 | transforms = tvt.Compose([
31 | tvt.Resize((32, 32)),
32 | tvt.ToTensor(),
33 | tvt.Normalize(0.5, 0.5)
34 | ])
35 | else:
36 | transforms = tvt.Compose([
37 | tvt.Resize((32, 32)),
38 | tvt.ToTensor(),
39 | tvt.Normalize(0.5, 0.5)
40 | ])
41 |
42 | loader = DataLoader(
43 | dataset=tvd.MNIST(root=root, train=train, transform=transforms, download=True),
44 | batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
45 | )
46 |
47 | return loader
48 |
--------------------------------------------------------------------------------
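These helpers return plain `torch.utils.data.DataLoader` instances, so a one-batch sanity check (a sketch; the shapes follow from the 32x32 resize and the chosen batch size) looks like:

```python
# Pull one MNIST batch and inspect its shapes.
from example_projects.image_autoencoder.dataloader import mnist_dataloader

loader = mnist_dataloader(root='./data/mnist', batch_size=4, train=True)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 1, 32, 32]) -- MNIST resized to 32x32
print(labels.shape)  # torch.Size([4])
```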
/example_projects/image_autoencoder/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class GAPFlatten(nn.Module):
6 |
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x):
11 | x = x.mean(dim=-1) # b, c, h
12 | x = x.mean(dim=-1) # b, c
13 | return x # b, c
14 |
15 |
16 | class GAPUnflatten(nn.Module):
17 |
18 | def __init__(self, output_shape=(4, 4)):
19 | super().__init__()
20 | self.h, self.w = output_shape
21 | self.linear = nn.Linear(1, self.h * self.w)
22 |
23 | def forward(self, x):
24 | b, c = x.shape[:2]
25 | x = x.reshape(b, c, 1) # (b, c, 1)
26 | x = self.linear(x) # (b, c, h*w)
27 | x = x.reshape(b, c, self.h, self.w) # (b, c, h, w)
28 | return x
29 |
30 |
31 | class TinyResNetAE(nn.Module):
32 |
33 | def __init__(self, in_channels, hidden_channels, expansion_factor=2, latent_image_size=(4, 4)):
34 | super().__init__()
35 | exf = expansion_factor
36 |
37 | self.encoder = nn.Sequential(
38 | nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1),
39 | nn.BatchNorm2d(hidden_channels),
40 | nn.ReLU(), # 32
41 | nn.MaxPool2d(kernel_size=exf, stride=exf), # 16
42 |
43 | nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels * exf, kernel_size=3, stride=1, padding=1),
44 | nn.BatchNorm2d(hidden_channels * exf),
45 | nn.ReLU(),
46 | nn.MaxPool2d(kernel_size=exf, stride=exf), # 8
47 |
48 | nn.Conv2d(in_channels=hidden_channels * exf, out_channels=hidden_channels * (exf**2), kernel_size=3, stride=1, padding=1),
49 | nn.BatchNorm2d(hidden_channels * (exf**2)),
50 | nn.ReLU(),
51 | nn.MaxPool2d(kernel_size=exf, stride=exf), # 4
52 |
53 | nn.Conv2d(in_channels=hidden_channels * (exf**2), out_channels=hidden_channels * (exf**2), kernel_size=3, stride=1, padding=1), # 4
54 | GAPFlatten(),
55 | )
56 |
57 | self.decoder = nn.Sequential(
58 | GAPUnflatten(output_shape=latent_image_size),
59 |
60 | nn.Conv2d(in_channels=hidden_channels * (exf**2), out_channels=hidden_channels * exf, kernel_size=3, stride=1, padding=1),
61 | nn.BatchNorm2d(hidden_channels * exf),
62 | nn.ReLU(),
63 | nn.Upsample(scale_factor=2, mode='bilinear'), # 8
64 |
65 | nn.Conv2d(in_channels=hidden_channels * exf, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1),
66 | nn.BatchNorm2d(hidden_channels),
67 | nn.ReLU(),
68 | nn.Upsample(scale_factor=2, mode='bilinear'), # 16
69 |
70 | nn.Conv2d(in_channels=hidden_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
71 | nn.BatchNorm2d(in_channels),
72 | nn.ReLU(),
73 | nn.Upsample(scale_factor=2, mode='bilinear'), # 32
74 |
75 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
76 | nn.Tanh() # Images will be normalized between [-1, 1]
77 | )
78 |
79 | def forward(self, x):
80 | encoding = self.encoder(x)
81 | decoded_image = self.decoder(encoding)
82 | return encoding, decoded_image
83 |
--------------------------------------------------------------------------------
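With the configuration used in the example trainer (`in_channels=1`, `hidden_channels=4`, `expansion_factor=2`), the encoder widens the channels as 4 → 8 → 16 while each max-pool halves the resolution (32 → 16 → 8 → 4), so global average pooling yields a 16-dimensional code; the decoder mirrors this back to 32x32. A dummy-input sketch to verify the shapes (CPU, no training):

```python
import torch

from example_projects.image_autoencoder.model import TinyResNetAE

model = TinyResNetAE(in_channels=1, hidden_channels=4,
                     expansion_factor=2, latent_image_size=(4, 4))
x = torch.randn(2, 1, 32, 32)           # dummy batch of two 32x32 images
encoding, decoded = model(x)
print(encoding.shape)  # torch.Size([2, 16]) -- hidden_channels * exf**2
print(decoded.shape)   # torch.Size([2, 1, 32, 32]), in [-1, 1] via Tanh
```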
/example_projects/image_autoencoder/train.py:
--------------------------------------------------------------------------------
1 | """Example usage of Pytorch Boiler"""
2 | import torch
3 |
4 | from pytorch_boiler import Boiler, overload
5 | from .model import TinyResNetAE
6 | from .dataloader import mnist_dataloader, cifar10_dataloader
7 |
8 |
9 | class Trainer(Boiler):
10 |
11 | @overload
12 | def pre_process(self, data):
13 | images, labels = data
14 | return images.cuda()
15 |
16 | @overload
17 | def loss(self, model_output, data): # Overload the loss function
18 | images, labels = data
19 | latent_encoding, predicted_images = model_output
20 | l2_loss = ((images.cuda() - predicted_images) ** 2).mean()
21 | return l2_loss # Can return a tensor, or a dictionary like {'l2_loss': l2_loss} with multiple losses. See README.
22 |
23 | @overload
24 | def performance(self, model_output, data):
25 | images, labels = data
26 | latent_encoding, predicted_images = model_output
27 | diff = (images.cuda() - predicted_images) ** 2
28 | reconstruction_error = diff.mean()
29 | max_error = diff.mean(dim=(1, 2, 3)).amax(dim=0)
30 | min_error = diff.mean(dim=(1, 2, 3)).amin(dim=0)
31 | return_values = {
32 | 'avg_reconstruction_error': reconstruction_error,
33 | 'avg_max_error': max_error,
34 | 'avg_min_error': min_error,
35 | 'summary': reconstruction_error
36 | }
37 | return return_values # Can return a tensor, or a dictionary like {'acc': acc} with multiple metrics. See README.
38 |
39 |
40 | if __name__ == '__main__':
41 | dataloader = mnist_dataloader; dataset = 'mnist'; in_channels = 1; batch_size = 256; exp_tag='autoencoder_mnist'
42 | # dataloader = cifar10_dataloader; dataset = 'cifar10'; in_channels = 3; batch_size = 256; exp_tag='autoencoder_cifar10'
43 | model = TinyResNetAE(in_channels=in_channels, hidden_channels=4, expansion_factor=2, latent_image_size=(4, 4)).cuda()
44 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)
45 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=1)
46 | train_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=True, shuffle=True, drop_last=True)
47 | val_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=False, shuffle=False, drop_last=False)
48 |
49 | trainer = Trainer(model=model, optimizer=optimizer, scheduler=scheduler,
50 | train_dataloader=train_dataloader, val_dataloader=val_dataloader,
51 | epochs=10, save_path=f'./model/{exp_tag}/state_dict.pt', load_path=None, mixed_precision=True)
52 | trainer.fit()
53 |
--------------------------------------------------------------------------------
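After `fit()`, reconstructions can be pulled through `trainer.infer`, which runs the forward pass in eval mode with gradients disabled. Since the inputs are normalized with mean 0.5 and std 0.5, the normalization is undone by `x * 0.5 + 0.5`. A sketch for the MNIST configuration above, assuming `matplotlib` is installed as in INSTALL.md:

```python
import matplotlib.pyplot as plt

images, labels = next(iter(val_dataloader))
_, reconstructed = trainer.infer((images, labels))
original = images[0, 0].numpy() * 0.5 + 0.5            # undo Normalize(0.5, 0.5)
recon = reconstructed[0, 0].cpu().numpy() * 0.5 + 0.5  # Tanh output lies in [-1, 1]
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.imshow(original, cmap='gray'); ax0.set_title('input')
ax1.imshow(recon, cmap='gray'); ax1.set_title('reconstruction')
plt.show()
```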
/example_projects/image_classifier/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmakes/pytorch_boiler/f60c547fec69de33e69a0689ede526f4be73bbb3/example_projects/image_classifier/__init__.py
--------------------------------------------------------------------------------
/example_projects/image_classifier/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torchvision.datasets as tvd
3 | import torchvision.transforms as tvt
4 |
5 |
6 | def cifar10_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
7 | if train:
8 | transforms = tvt.Compose([
9 | tvt.RandomHorizontalFlip(),
10 | tvt.RandomResizedCrop(size=32, scale=(0.08, 1)),
11 | tvt.ToTensor(),
12 | tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
13 | ])
14 | else:
15 | transforms = tvt.Compose([
16 | tvt.ToTensor(),
17 | tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
18 | ])
19 |
20 | loader = DataLoader(
21 | dataset=tvd.CIFAR10(root=root, train=train, transform=transforms, download=True),
22 | batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
23 | )
24 |
25 | return loader
26 |
27 |
28 | def mnist_dataloader(root='./data', train=True, batch_size=4, drop_last=True, shuffle=True, num_workers=4):
29 | if train:
30 | transforms = tvt.Compose([
31 | tvt.ToTensor(),
32 | tvt.Normalize(0.5, 0.5)
33 | ])
34 | else:
35 | transforms = tvt.Compose([
36 | tvt.ToTensor(),
37 | tvt.Normalize(0.5, 0.5)
38 | ])
39 |
40 | loader = DataLoader(
41 | dataset=tvd.MNIST(root=root, train=train, transform=transforms, download=True),
42 | batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last
43 | )
44 |
45 | return loader
46 |
--------------------------------------------------------------------------------
/example_projects/image_classifier/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class TinyResNet(nn.Module):
6 |
7 | def __init__(self, in_channels, hidden_channels, output_channels, num_layers, expansion_factor=2):
8 | super().__init__()
9 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1)
10 |
11 | self.conv_blocks = nn.ModuleList([
12 | nn.Sequential(
13 | nn.Conv2d(hidden_channels * (2**i), hidden_channels * (2**i), kernel_size=3, stride=1, padding=1),
14 | nn.BatchNorm2d(hidden_channels * (2**i)),
15 | nn.ReLU(),
16 | nn.Conv2d(hidden_channels * (2**i), hidden_channels * (2**(i+1)), kernel_size=3, stride=1, padding=1),
17 | nn.BatchNorm2d(hidden_channels * (2**(i+1))),
18 | nn.ReLU(),
19 | nn.Conv2d(hidden_channels * (2**(i+1)), hidden_channels * (2**(i+1)), kernel_size=3, stride=1, padding=1),
20 | nn.BatchNorm2d(hidden_channels * (2**(i+1))),
21 | nn.ReLU(),
22 | )
23 | for i in range(num_layers)
24 | ])
25 |
26 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
27 |
28 | self.skip_connections = nn.ModuleList([
29 | nn.Conv2d(hidden_channels * (2**i), hidden_channels * (2**(i+1)), kernel_size=3, stride=1, padding=1)
30 | for i in range(num_layers)
31 | ])
32 |
33 | self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
34 |
35 | self.fc = nn.Sequential(
36 | nn.Linear(hidden_channels * (2**(num_layers)), hidden_channels * (2**(num_layers))),
37 | nn.ReLU(),
38 | nn.Linear(hidden_channels * (2**(num_layers)), output_channels)
39 | )
40 |
41 | def forward(self, x):
42 | x = self.conv1(x) # (B, C, H, W)
43 |
44 | for conv_block, skip_conn in zip(self.conv_blocks, self.skip_connections):
45 | block_out = conv_block(x)
46 | skip_out = skip_conn(x)
47 | x = block_out + skip_out
48 | x = self.max_pool(x)
49 |
50 | x = self.global_avg_pool(x) # (B, C, 1, 1)
51 | x = x.squeeze(-1).squeeze(-1)
52 |
53 | x = self.fc(x)
54 | return x
55 |
--------------------------------------------------------------------------------
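With the example configuration (`hidden_channels=4`, `num_layers=3`), each block doubles the channels (4 → 8 → 16 → 32) while each max-pool halves the resolution (32 → 16 → 8 → 4), and global average pooling feeds a 32-dimensional feature to the classifier head. Note that `expansion_factor` is accepted but unused; the blocks hard-code a factor of 2, so only the default value is consistent. A dummy-input sketch to verify the output shape:

```python
import torch

from example_projects.image_classifier.model import TinyResNet

model = TinyResNet(in_channels=3, hidden_channels=4, output_channels=10,
                   num_layers=3, expansion_factor=2)
x = torch.randn(2, 3, 32, 32)                       # dummy CIFAR-sized batch
logits = model(x)
print(logits.shape)                                 # torch.Size([2, 10])
print(sum(p.numel() for p in model.parameters()))   # total parameter count
```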
/example_projects/image_classifier/train.py:
--------------------------------------------------------------------------------
1 | """Example usage of Pytorch Boiler"""
2 | import torch
3 |
4 | from pytorch_boiler import Boiler, overload
5 | from .model import TinyResNet
6 | from .dataloader import mnist_dataloader, cifar10_dataloader
7 |
8 |
9 | class Trainer(Boiler):
10 |
11 | @overload
12 | def pre_process(self, data):
13 | images, labels = data
14 | return images.cuda()
15 |
16 | @overload
17 | def loss(self, model_output, data): # Overload the loss function
18 | images, labels = data
19 | xe_loss = torch.nn.functional.cross_entropy(model_output, labels.cuda(), reduction='mean')
20 | return xe_loss # Can return a tensor, or a dictionary like {'xe_loss': xe_loss} with multiple losses. See README.
21 |
22 | @overload
23 | def performance(self, model_output, data):
24 | images, labels = data
25 | preds = model_output.argmax(dim=-1)
26 | acc = (preds == labels.cuda()).float().mean()
27 | return acc.cpu().detach().numpy() # Can return a tensor, or a dictionary like {'acc': acc} with multiple metrics. See README.
28 |
29 |
30 | if __name__ == '__main__':
31 | # dataloader = mnist_dataloader; dataset = 'mnist'; in_channels = 1; batch_size = 256; exp_tag='classifier_mnist'
32 | dataloader = cifar10_dataloader; dataset = 'cifar10'; in_channels = 3; batch_size = 256; exp_tag='classifier_cifar10'
33 | model = TinyResNet(in_channels=in_channels, hidden_channels=4, output_channels=10, num_layers=3, expansion_factor=2).cuda()
34 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
35 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=1)
36 | train_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=True, shuffle=True, drop_last=True)
37 | val_dataloader = dataloader(root=f'./data/{dataset}', batch_size=batch_size, train=False, shuffle=False, drop_last=False)
38 |
39 | trainer = Trainer(model=model, optimizer=optimizer, scheduler=scheduler,
40 | train_dataloader=train_dataloader, val_dataloader=val_dataloader,
41 | epochs=10, save_path=f'./model/{exp_tag}/state_dict.pt', load_path=None, mixed_precision=True)
42 | trainer.fit()
43 |
--------------------------------------------------------------------------------
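`Boiler` checkpoints the full training state (model, optimizer, grad scaler, scheduler, tracker, epoch, and best validation loss) whenever the validation loss improves. Passing the same path as `load_path` resumes training from the saved epoch; a sketch reusing the objects from the `__main__` block above:

```python
# Resume from the checkpoint written via save_path.
trainer = Trainer(model=model, optimizer=optimizer, scheduler=scheduler,
                  train_dataloader=train_dataloader, val_dataloader=val_dataloader,
                  epochs=20,  # must exceed the checkpoint's epoch for training to continue
                  save_path=f'./model/{exp_tag}/state_dict.pt',
                  load_path=f'./model/{exp_tag}/state_dict.pt',
                  mixed_precision=True)
trainer.fit()  # starts at the saved epoch + 1
```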
/pytorch_boiler/__init__.py:
--------------------------------------------------------------------------------
1 | from .boiler import (
2 | Boiler
3 | )
4 |
5 | from .utils import (
6 | overload
7 | )
8 |
9 | from .tracker import (
10 | Tracker
11 | )
12 |
13 | __all__ = [
14 | # Boiler imports
15 | 'Boiler',
16 |
17 | # Util imports
18 | 'overload',
19 |
20 | # Tracker imports
21 | 'Tracker'
22 | ]
23 |
--------------------------------------------------------------------------------
/pytorch_boiler/boiler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from tqdm import tqdm
8 | from typing import Any
9 |
10 | from .utils import init_overload_state, is_method_overloaded, prettify_dict
11 | from .tracker import Tracker
12 |
13 |
14 | class Boiler(nn.Module):
15 |
16 | def __init__(self, model, optimizer, scheduler, train_dataloader, val_dataloader, epochs, patience=None, save_path=None, load_path=None, mixed_precision=False):
17 | super(Boiler, self).__init__()
18 | self.model = model
19 | self.optimizer = optimizer
20 | self.scheduler = scheduler
21 | self.train_dataloader = train_dataloader
22 | self.val_dataloader = val_dataloader
23 | self.epochs = epochs
24 | self.patience = patience
25 |
26 | self.save_path = save_path
27 | self.load_path = load_path
28 |
29 | self.mixed_precision = mixed_precision
30 |
31 | # Initialize tracker
32 | self.tracker = Tracker()
33 | self.patience_counter = 0
34 |
35 | # Initialize mixed precision
36 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)
37 |
38 | # Load model state
39 | self.start_epoch = 0
40 | self.best_validation_loss = float('inf')
41 | if self.load_path is not None:
42 | print('\nBoiler | Loading from {}'.format(self.load_path))
43 | loaded_object = torch.load(self.load_path)
44 | self.model.load_state_dict(loaded_object['model_state_dict'])
45 | self.optimizer.load_state_dict(loaded_object['optimizer_state_dict'])
46 | self.scaler.load_state_dict(loaded_object['scaler_state_dict'])
47 | self.scheduler.load_state_dict(loaded_object['scheduler_state_dict'])
48 | self.tracker.load_state_dict(loaded_object['tracker'])
49 | self.start_epoch = loaded_object['epoch'] + 1
50 | self.best_validation_loss = loaded_object['best_validation_loss']
51 |
52 | @init_overload_state
53 | def pre_process(self, x):
54 | """Computes the pre processing steps. Must be overloaded for usage.
55 |
56 | Args:
57 | x (Any): The input to the pre processor
58 |
59 | Raises:
60 | NotImplementedError: If this function is not implemented and is called.
61 | """
62 | raise NotImplementedError()
63 |
64 | @init_overload_state
65 | def post_process(self, x):
66 | """Computes the post processing steps. Must be overloaded for usage.
67 |
68 | Args:
69 | x (Any): The input to the post processor (output from the model forward pass).
70 |
71 | Raises:
72 | NotImplementedError: If this function is not implemented and is called.
73 | """
74 | raise NotImplementedError()
75 |
76 | @init_overload_state
77 | def forward(self, x):
78 | """Computes the forward pass, with additional pre and post processing.
79 |
80 | Args:
81 | x (torch.Tensor|Any): The tensor input to the model forward pass, or input to the pre processor.
82 |
83 | Returns:
84 | torch.Tensor|Any: The tensor output of the model, or the output from the post processor.
85 | """
86 | if is_method_overloaded(self.pre_process):
87 | x = self.pre_process(x)
88 | with torch.autocast('cuda', dtype=torch.float16, enabled=(self.mixed_precision and self.training)):
89 | x = self.model(x)
90 | if is_method_overloaded(self.post_process):
91 | x = self.post_process(x)
92 | return x
93 |
94 | @init_overload_state
95 | def loss(self, output, data):
96 | """Computes the loss. Must be overloaded for usage.
97 |
98 | Args:
99 | output(torch.Tensor|Any): The output from the model forward pass, or output from the post processor.
100 | data(torch.Tensor|Any): The data batch loaded from the dataloader.
101 |
102 | Raises:
103 | NotImplementedError: If this function is not implemented and is called.
104 | """
105 | raise NotImplementedError()
106 |
107 | @init_overload_state
108 | def performance(self, output, data):
109 | """Computes performance metric. Must be overloaded for usage.
110 |
111 | Args:
112 | output (torch.Tensor|Any): The output from the model forward pass, or output from the post processor.
113 | data (torch.Tensor|Any): The data batch loaded from the dataloader.
114 |
115 | Raises:
116 | NotImplementedError: If this function is not implemented and is called.
117 | """
118 | raise NotImplementedError()
119 |
120 | @init_overload_state
121 | def infer(self, x):
122 | """Computes the inference in eval mode.
123 |
124 | Args:
125 | x (torch.Tensor|Any): The tensor input to the model forward pass, or input to the pre processor.
126 |
127 | Returns:
128 | torch.Tensor|Any: The tensor output of the model, or the output from the post processor.
129 | """
130 | # If model is in training mode, convert to eval
131 | training_state_true = self.training
132 | if training_state_true:
133 | self.eval()
134 |
135 | with torch.no_grad():
136 | output = self.forward(x)
137 |
138 | # If model was in training mode, convert back to train from eval
139 | if training_state_true:
140 | self.train()
141 |
142 | return output
143 |
144 | @init_overload_state
145 | def decode_item_type(self, item, require_summary=True):
146 | """Decodes the type of the item. This function is used to handle multiple metrics
147 |
148 | Args:
149 | item (torch.Tensor|np.ndarray|dict): Either a tensor / numpy array describing a loss or metric, or a dictionary of named values.
150 | require_summary (bool): If True, a dictionary item is enforced to contain the key 'summary' when more than one key is present.
151 |
152 | Returns:
153 | dict: A dictionary with named tensors, where "summary" is a special key that specifies the loss to be backpropagated / metric to be tracked.
154 | """
155 | assert type(item) in [torch.Tensor, np.ndarray, dict], f"Either a tensor, numpy array, or a dictionary of (tag: value) must be passed. Given {type(item)}."
156 | if type(item) == torch.Tensor or type(item) == np.ndarray:
157 | return {'summary': item}
158 | elif type(item) == dict:
159 | keys = list(item.keys())
160 | if len(keys) == 1:
161 | item['summary'] = item[keys[0]]
162 | if require_summary:
163 | assert 'summary' in item.keys(), "If a dictionary is returned from loss / performance, it must contain the key called 'summary' indicating the loss to backpropagate / metric to track."
164 | return item
165 |
166 | @init_overload_state
167 | def train_epoch(self):
168 | """Trains the model for one epoch.
169 |
170 | Args:
171 | None
172 |
173 | Returns:
174 | Tracker: The tracker object.
175 | """
176 | self.train() # Set the mode to training
177 | for data in tqdm(self.train_dataloader):
178 | # Set optimizer zero grad
179 | self.optimizer.zero_grad()
180 |
181 | # Compute model output
182 | model_output = self.forward(data)
183 |
184 | # Compute loss & Update tracker
185 | loss = self.loss(model_output, data)
186 | decoded_loss = self.decode_item_type(loss, require_summary=True)
187 | self.tracker.update(key='training_loss',
188 | value=decoded_loss['summary'].cpu().detach().numpy())
189 | for key in decoded_loss:
190 | if key != 'summary':
191 | self.tracker.update(key='training_{}'.format(key),
192 | value=decoded_loss[key].cpu().detach().numpy())
193 |
194 | # Compute performance and update tracker
195 | if is_method_overloaded(self.performance):
196 | perf = self.performance(model_output, data)
197 | decoded_perf = self.decode_item_type(perf, require_summary=True)
198 | self.tracker.update(key='training_perf',
199 | value=decoded_perf['summary'].cpu().detach().numpy() \
200 | if type(decoded_perf['summary'])==torch.Tensor \
201 | else decoded_perf['summary'])
202 | for key in decoded_perf:
203 | if key != 'summary':
204 | self.tracker.update(key='training_{}'.format(key),
205 | value=decoded_perf[key].cpu().detach().numpy() \
206 | if type(decoded_perf[key])==torch.Tensor \
207 | else decoded_perf[key])
208 |
209 | if self.mixed_precision:
210 | # Scale the loss
211 | scaled_loss = self.scaler.scale(decoded_loss['summary'])
212 | # Backpropagate loss
213 | scaled_loss.backward()
214 | # Update model parameters
215 | self.scaler.step(self.optimizer)
216 | self.scaler.update()
217 | else:
218 | # Backpropagate loss
219 | decoded_loss['summary'].backward()
220 | # Update model parameters
221 | self.optimizer.step()
222 | return self.tracker
223 |
224 | @init_overload_state
225 | def eval_epoch(self):
226 | """Evaluates the model for one epoch.
227 |
228 | Args:
229 | None
230 |
231 | Returns:
232 | Tracker: The tracker object.
233 | """
234 | self.eval() # Set the mode to evaluation
235 | for data in tqdm(self.val_dataloader):
236 | # Compute model output
237 | model_output = self.infer(data)
238 |
239 | # Compute loss & Update tracker
240 | loss = self.loss(model_output, data)
241 | decoded_loss = self.decode_item_type(loss, require_summary=True)
242 | self.tracker.update('validation_loss', decoded_loss['summary'].cpu().detach().numpy())
243 | for key in decoded_loss:
244 | if key != 'summary':
245 | self.tracker.update('validation_{}'.format(key), decoded_loss[key].cpu().detach().numpy())
246 |
247 | # Compute performance & Update tracker
248 | if is_method_overloaded(self.performance):
249 | perf = self.performance(model_output, data)
250 | decoded_perf = self.decode_item_type(perf, require_summary=True)
251 | self.tracker.update('validation_perf', decoded_perf['summary'].cpu().detach().numpy() if type(decoded_perf['summary'])==torch.Tensor else decoded_perf['summary'])
252 | for key in decoded_perf:
253 | if key != 'summary':
254 | self.tracker.update('validation_{}'.format(key), decoded_perf[key].cpu().detach().numpy() if type(decoded_perf[key])==torch.Tensor else decoded_perf[key])
255 |
256 | return self.tracker
257 |
258 | @init_overload_state
259 | def fit(self):
260 | """Fits the model for the given number of epochs.
261 |
262 | Returns:
263 | Boiler: Self.
264 | """
265 | best_validation_loss = self.best_validation_loss
266 | start_epoch = self.start_epoch
267 |
268 | for e in range(start_epoch, self.epochs):
269 | print('\nBoiler | Training epoch {}/{}...'.format(e+1, self.epochs))
270 | self.train_epoch()
271 | self.eval_epoch()
272 | self.scheduler.step()
273 | summary = self.tracker.summarize()
274 | print(prettify_dict(summary))
275 | self.tracker.stash()
276 |
277 | if summary['validation_loss'] < best_validation_loss:
278 | self.patience_counter = 0
279 | to_save = True
280 | else:
281 | self.patience_counter += 1
282 | to_save = False
283 |
284 | if self.save_path is not None and to_save:
285 | best_validation_loss = summary['validation_loss']
286 | print('Saving training state at {}'.format(self.save_path))
287 | os.makedirs(os.path.abspath(os.path.dirname(self.save_path)), exist_ok=True)
288 | saved_object = {
289 | 'model_state_dict': self.model.state_dict(),
290 | 'optimizer_state_dict': self.optimizer.state_dict(),
291 | 'scaler_state_dict': self.scaler.state_dict(),
292 | 'scheduler_state_dict': self.scheduler.state_dict(),
293 | 'tracker': self.tracker.state_dict(),
294 | 'epoch': e,
295 | 'best_validation_loss': best_validation_loss
296 | }
297 | torch.save(saved_object, self.save_path)
298 |
299 | if (self.patience is not None) and (self.patience_counter > self.patience):
300 | print('Maximum patience reached. Stopping training.')
301 | break
302 |
303 | return self
304 |
--------------------------------------------------------------------------------
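Two features of `Boiler` that the example trainers do not exercise are the `post_process` hook, which `forward` invokes only when it is tagged with `@overload`, and the `patience` argument, which stops training once the validation loss has failed to improve for more than `patience` consecutive epochs. A hypothetical sketch of both (the subclass and save path are illustrative; `model`, `optimizer`, and the dataloaders are assumed to be defined as in the example trainers):

```python
import torch

from pytorch_boiler import Boiler, overload


class SoftmaxTrainer(Boiler):  # hypothetical subclass for illustration

    @overload
    def pre_process(self, data):
        images, labels = data
        return images.cuda()

    @overload
    def post_process(self, logits):
        # Runs inside forward() after the model, only because @overload tags it.
        return torch.softmax(logits, dim=-1)

    @overload
    def loss(self, model_output, data):
        images, labels = data
        # model_output is post-processed (probabilities), so use NLL on their logs.
        return torch.nn.functional.nll_loss(model_output.log(), labels.cuda())


trainer = SoftmaxTrainer(model=model, optimizer=optimizer, scheduler=scheduler,
                         train_dataloader=train_dataloader, val_dataloader=val_dataloader,
                         epochs=100, patience=3,  # stop once val loss stalls for >3 epochs
                         save_path='./model/softmax_example/state_dict.pt')
trainer.fit()
```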
/pytorch_boiler/tracker.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import json
4 |
5 |
6 | class Tracker:
7 |
8 | def __init__(self):
9 | self.history = {}
10 | self.running_history = []
11 |
12 | def update(self, key, value):
13 | if key not in self.history:
14 | self.history[key] = []
15 | self.history[key].append(value)
16 |
17 | def stash(self):
18 | self.running_history.append(self.history)
19 | self.history = {}
20 |
21 | def summarize(self):
22 | summary = {}
23 | for key in self.history:
24 | summary[key] = np.mean(self.history[key]).item()
25 | return summary
26 |
27 | def state_dict(self):
28 | state = {
29 | 'history': self.history,
30 | 'running_history': self.running_history
31 | }
32 | return state
33 |
34 | def load_state_dict(self, state_dict):
35 | self.history = state_dict['history']
36 | self.running_history = state_dict['running_history']
37 | return self
--------------------------------------------------------------------------------
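`Tracker` accumulates per-batch values under a key, `summarize` averages them with `np.mean`, and `stash` archives the epoch's history and resets it. A standalone sketch:

```python
from pytorch_boiler import Tracker

tracker = Tracker()
for batch_loss in [1.0, 0.75, 0.5]:           # e.g. one value per batch
    tracker.update(key='training_loss', value=batch_loss)
print(tracker.summarize())                    # {'training_loss': 0.75}
tracker.stash()                               # archive the epoch, reset history
print(tracker.history)                        # {}
```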
/pytorch_boiler/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | # -----------------------------
5 | # Utilities to override methods
6 | # -----------------------------
7 | def overload(func):
8 | func.is_overloaded = True
9 | return func
10 |
11 | def is_method_overloaded(func):
12 | if not hasattr(func, 'is_overloaded'):
13 | return False
14 | else:
15 | return func.is_overloaded
16 |
17 | def init_overload_state(func):
18 | if not is_method_overloaded(func):
19 | func.is_overloaded = False
20 | return func
21 |
22 |
23 | # ---------------
24 | # Text formatting
25 | # ---------------
26 | def prettify_dict(d):
27 | return json.dumps(d, indent=2)
28 |
--------------------------------------------------------------------------------
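The `@overload` decorator only tags a function, while `@init_overload_state`, applied to every base method in `Boiler`, makes untagged methods default to `is_overloaded = False`; `is_method_overloaded` is then how `Boiler` decides, for example, whether to call `pre_process` at all. A minimal sketch outside of `Boiler`:

```python
from pytorch_boiler.utils import init_overload_state, is_method_overloaded, overload


class Base:
    @init_overload_state
    def pre_process(self, x):
        raise NotImplementedError()


class Child(Base):
    @overload
    def pre_process(self, x):  # tagged, so Boiler-style code would call it
        return x


print(is_method_overloaded(Base().pre_process))   # False
print(is_method_overloaded(Child().pre_process))  # True
```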