├── README.md ├── requirements.txt ├── src ├── dataset.py ├── model.py ├── run.py └── trainer.py └── tests ├── templates.py ├── test_dataset.py ├── test_model.py └── test_trainer.py /README.md: -------------------------------------------------------------------------------- 1 | # How to Trust Your Deep Learning Code 2 | 3 | This is the support repository for the blog post [How to Trust Your Deep Learning Code](http://krokotsch.eu/cleancode/2020/08/11/Unit-Tests-for-Deep-Learning.html). 4 | It contains code for training a Variational Autoencoder (VAE) and the associated unit tests. 5 | The unit tests illustrate useful concepts to test in deep learning projects. 6 | The focus lay on writing tests that are readable and reusable. 7 | 8 | For more information check out the blog post. 9 | 10 | ## Usage 11 | 12 | The project uses Python 3.7. 13 | First, install the packages specified in the `requirements.txt` file. 14 | 15 | ``` 16 | conda create -n unittest_dl python=3.7 17 | conda activate unittest_dl 18 | conda install --file requirements.txt -c pytorch 19 | 20 | ## or 21 | 22 | virtualenv -p python3.7 unittest_dl 23 | source unittest_dl/bin/activate 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | This project was developed in PyCharm, so it is the easiest to use that way. 28 | Open it in the IDE and mark the `src` directory as Sources Root (right-click the folder > Mark directory as > Sources Root). 29 | Everything should work out of the box now. 30 | To run all tests, right-click the `tests` directory and select `"Run 'Unittests in tests'"`. 31 | 32 | As an alternative, you can manually add `src` to your `PYTHON_PATH` environment variable. 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch=1.5.1 2 | scipy=1.5.0 3 | tensorboard=2.2.1 4 | torchvision=0.6.1 5 | tqdm=4.47.0 6 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision 4 | import torchvision.transforms as forms 5 | 6 | 7 | class MyMNIST: 8 | def __init__(self): 9 | script_path = os.path.dirname(__file__) 10 | data_root = os.path.join(script_path, '..', 'data') 11 | os.makedirs(data_root, exist_ok=True) 12 | 13 | # Pad to 32x32, augment and scale to [-1, 1] 14 | train_transforms = forms.Compose([forms.Pad(2), 15 | forms.RandomRotation(5), 16 | forms.ToTensor(), 17 | forms.Lambda(lambda x: 2*x-1)]) 18 | # Pad to 32x32 and scale to [-1, 1] 19 | test_transforms = forms.Compose([forms.Pad(2), 20 | forms.ToTensor(), 21 | forms.Lambda(lambda x: 2 * x - 1)]) 22 | 23 | self.train_data = torchvision.datasets.MNIST(data_root, 24 | train=True, 25 | transform=train_transforms, 26 | download=True) 27 | self.test_data = torchvision.datasets.MNIST(data_root, 28 | train=False, 29 | transform=test_transforms, 30 | download=True) 31 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class VAE(nn.Module): 6 | def __init__(self, bottleneck_dim): 7 | super(VAE, self).__init__() 8 | 9 | self.bottleneck_dim = bottleneck_dim 10 | self.encoder = None 11 | self.decoder = None 12 | 13 | def forward(self, inputs): 14 | mu, log_sigma = self.encode(inputs) 15 | latent_code = self.bottleneck(mu, log_sigma) 16 | outputs = self.decode(latent_code) 17 | 18 | return outputs 19 | 20 | def encode(self, inputs): 21 | latent_parameters = self.encoder(inputs) 22 | mu, log_sigma = torch.split(latent_parameters, self.bottleneck_dim, dim=1) 23 | 24 | return mu, log_sigma 25 | 26 | def bottleneck(self, mu, log_sigma): 27 | noise = torch.randn_like(mu) 28 | latent_code = log_sigma.exp() * noise + mu 29 | 30 | return latent_code 31 | 32 | def decode(self, latent_code): 33 | return self.decoder(latent_code) 34 | 35 | 36 | class CNNVAE(VAE): 37 | def __init__(self, input_shape, bottleneck_dim): 38 | super(CNNVAE, self).__init__(bottleneck_dim) 39 | 40 | in_channels = input_shape[0] 41 | hw = input_shape[1] 42 | hw_before_linear = hw // 4 43 | flat_dim = 64 * hw_before_linear ** 2 44 | 45 | self.encoder = nn.Sequential(nn.Conv2d(in_channels, out_channels=16, kernel_size=5, padding=2), 46 | nn.ReLU(True), 47 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 48 | nn.ReLU(True), 49 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 50 | nn.ReLU(True), 51 | nn.Flatten(), 52 | nn.Linear(flat_dim, 2*bottleneck_dim)) 53 | 54 | self.decoder = nn.Sequential(nn.Linear(bottleneck_dim, flat_dim), 55 | nn.ReLU(True), 56 | Unflatten((64, hw_before_linear, hw_before_linear)), 57 | nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, 58 | padding=1, output_padding=1), 59 | nn.ReLU(True), 60 | nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, 61 | padding=1, output_padding=1), 62 | nn.ReLU(True), 63 | nn.Conv2d(in_channels=16, out_channels=in_channels, kernel_size=5, padding=2), 64 | nn.Tanh()) 65 | 66 | 67 | class MLPVAE(VAE): 68 | def __init__(self, input_shape, bottleneck_dim): 69 | super(MLPVAE, self).__init__(bottleneck_dim) 70 | 71 | in_channels = input_shape[0] 72 | hw = input_shape[1] 73 | flat_dim = in_channels * (hw ** 2) 74 | 75 | self.encoder = nn.Sequential(nn.Flatten(), 76 | nn.Linear(flat_dim, out_features=512), 77 | nn.ReLU(True), 78 | nn.Linear(512, 256), 79 | nn.ReLU(True), 80 | nn.Linear(256, 128), 81 | nn.ReLU(True), 82 | nn.Linear(128, 2*bottleneck_dim)) 83 | 84 | self.decoder = nn.Sequential(nn.Linear(bottleneck_dim, 128), 85 | nn.ReLU(True), 86 | nn.Linear(128, 256), 87 | nn.ReLU(True), 88 | nn.Linear(256, 512), 89 | nn.ReLU(True), 90 | nn.Linear(512, flat_dim), 91 | Unflatten(input_shape), 92 | nn.Tanh()) 93 | 94 | 95 | class Unflatten(nn.Module): 96 | def __init__(self, shape): 97 | """ 98 | Reshapes a batch of flat tensors to the given shape. 99 | 100 | :param shape: expected output shape without batch dimension 101 | """ 102 | super(Unflatten, self).__init__() 103 | 104 | self.shape = shape 105 | 106 | def forward(self, inputs): 107 | return torch.reshape(inputs, (-1,) + self.shape) 108 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import dataset 4 | import model 5 | import trainer 6 | 7 | 8 | def run(network_type, bottleneck_dim, lr, batch_size, epochs, device, log_dir): 9 | mnist_data = dataset.MyMNIST() 10 | 11 | if network_type == 'mlp': 12 | net = model.MLPVAE((1, 32, 32), bottleneck_dim) 13 | elif network_type == 'cnn': 14 | net = model.CNNVAE((1, 32, 32), bottleneck_dim) 15 | else: 16 | raise ValueError(f'Unsupported network type {network_type}. Chose between "mlp" and "cnn".') 17 | 18 | optim = torch.optim.Adam(net.parameters(), lr) 19 | vae_trainer = trainer.Trainer(net, mnist_data, optim, batch_size, device, log_dir) 20 | vae_trainer.train(epochs) 21 | 22 | 23 | if __name__ == '__main__': 24 | import argparse 25 | parser = argparse.ArgumentParser(description='Run the training for a VAE.') 26 | parser.add_argument('-t', '--network_type', required=True, choices=['mlp', 'cnn'], help='type of the VAE network') 27 | parser.add_argument('-n', '--bottleneck_dim', default=16, type=int, help='size of the VAE bottleneck') 28 | parser.add_argument('-r', '--lr', default=0.001, type=float, help='learning rate for training') 29 | parser.add_argument('-b', '--batch_size', required=True, type=int, help='batch size for training') 30 | parser.add_argument('-e', '--epochs', required=True, type=int, help='epochs to train') 31 | parser.add_argument('-d', '--device', default='cpu', help='device to train on, e.g. "cuda:0"') 32 | parser.add_argument('-l', '--logdir', default='./results', help='directory to log the models and event file to') 33 | opt = parser.parse_args() 34 | 35 | run(opt.network_type, opt.bottleneck_dim, opt.lr, opt.batch_size, opt.epochs, opt.device, opt.logdir) 36 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tqdm 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader, Subset 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | class Trainer: 11 | def __init__(self, model, data, optimizer, batch_size, device, log_dir='./results', num_generated_images=10): 12 | self.model = model.to(device) 13 | self.data = data 14 | self.optimizer = optimizer 15 | self.device = device 16 | self.log_dir = log_dir 17 | self._num_generated_images = num_generated_images 18 | 19 | self._epoch = 0 20 | self._step = 0 21 | self._train_data = DataLoader(self.data.train_data, batch_size, shuffle=True, num_workers=2) 22 | self._test_data = DataLoader(self.data.test_data, batch_size, shuffle=False, num_workers=2) 23 | self._progress_bar = None 24 | 25 | self.summary = SummaryWriter(log_dir) 26 | 27 | def train(self, epochs): 28 | self._epoch = 0 29 | self._step = 0 30 | for e in range(epochs): 31 | self._epoch = e 32 | self._train_epoch() 33 | self._eval_and_log() 34 | self._save_model() 35 | 36 | def _train_epoch(self): 37 | self._progress_bar = tqdm.tqdm(self._train_data) 38 | self._progress_bar.set_description(f'Epoch {self._epoch}') 39 | 40 | self.model.train() 41 | for batch in self._progress_bar: 42 | self._train_step(batch) 43 | 44 | def _train_step(self, batch): 45 | self.optimizer.zero_grad() 46 | kl_div_loss, recon_loss = self._calc_loss(batch) 47 | loss = recon_loss + kl_div_loss 48 | loss.backward() 49 | self.optimizer.step() 50 | 51 | self._progress_bar.set_postfix({'loss': loss.item()}) 52 | self.summary.add_scalar('train/recon_loss', recon_loss.item(), self._step) 53 | self.summary.add_scalar('train/kl_div_loss', kl_div_loss.item(), self._step) 54 | self.summary.add_scalar('train/loss', loss.item(), self._step) 55 | self._step += 1 56 | 57 | def _calc_loss(self, batch): 58 | inputs, _ = batch 59 | mu, log_sigma = self.model.encode(inputs) 60 | latent_code = self.model.bottleneck(mu, log_sigma) 61 | outputs = self.model.decode(latent_code) 62 | recon_loss = F.mse_loss(outputs, inputs, reduction='sum') 63 | kl_div_loss = self._kl_divergence(log_sigma, mu) 64 | 65 | return kl_div_loss, recon_loss 66 | 67 | @staticmethod 68 | def _kl_divergence(log_sigma, mu): 69 | return 0.5 * torch.sum((2 * log_sigma).exp() + mu ** 2 - 1 - 2 * log_sigma) 70 | 71 | @torch.no_grad() 72 | def eval(self): 73 | self.model.eval() 74 | 75 | eval_loss = 0. 76 | for batch in self._test_data: 77 | eval_loss += self._eval_step(batch) 78 | eval_loss /= len(self.data.test_data) 79 | 80 | return eval_loss 81 | 82 | def generate(self, n): 83 | self.model.eval() 84 | 85 | samples = torch.stack([self.data.test_data[i][0] for i in range(n)]) 86 | images = self.model(samples) 87 | images = torch.cat([samples, images], dim=3) 88 | 89 | return images 90 | 91 | def _eval_step(self, batch): 92 | _, recon_loss = self._calc_loss(batch) 93 | 94 | return recon_loss.item() 95 | 96 | def _eval_and_log(self): 97 | print(f'Evaluate epoch {self._epoch}: ', end='') 98 | eval_loss = self.eval() 99 | print(eval_loss) 100 | self.summary.add_scalar('test/loss', eval_loss, self._epoch) 101 | 102 | images = self.generate(n=self._num_generated_images) 103 | for i, img in enumerate(images): 104 | self.summary.add_image(f'test/{i}', img, global_step=self._epoch) 105 | 106 | def _save_model(self): 107 | save_path = os.path.join(self.log_dir, f'model_{str(self._epoch).zfill(3)}.pth') 108 | torch.save(self.model.cpu(), save_path) 109 | -------------------------------------------------------------------------------- /tests/templates.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class ModelTestsMixin: 8 | @torch.no_grad() 9 | def test_shape(self): 10 | outputs = self.net(self.test_inputs) 11 | self.assertEqual(self.test_inputs.shape, outputs.shape) 12 | 13 | @torch.no_grad() 14 | @unittest.skipUnless(torch.cuda.is_available(), 'No GPU was detected') 15 | def test_device_moving(self): 16 | net_on_gpu = self.net.to('cuda:0') 17 | net_back_on_cpu = net_on_gpu.cpu() 18 | 19 | torch.manual_seed(42) 20 | outputs_cpu = self.net(self.test_inputs) 21 | torch.manual_seed(42) 22 | outputs_gpu = net_on_gpu(self.test_inputs.to('cuda:0')) 23 | torch.manual_seed(42) 24 | outputs_back_on_cpu = net_back_on_cpu(self.test_inputs) 25 | 26 | self.assertAlmostEqual(0., torch.sum(outputs_cpu - outputs_gpu.cpu())) 27 | self.assertAlmostEqual(0., torch.sum(outputs_cpu - outputs_back_on_cpu)) 28 | 29 | def test_batch_independence(self): 30 | inputs = self.test_inputs.clone() 31 | inputs.requires_grad = True 32 | 33 | # Compute forward pass in eval mode to deactivate batch norm 34 | self.net.eval() 35 | outputs = self.net(inputs) 36 | self.net.train() 37 | 38 | # Mask loss for certain samples in batch 39 | batch_size = inputs[0].shape[0] 40 | mask_idx = torch.randint(0, batch_size, ()) 41 | mask = torch.ones_like(outputs) 42 | mask[mask_idx] = 0 43 | outputs = outputs * mask 44 | 45 | # Compute backward pass 46 | loss = outputs.mean() 47 | loss.backward() 48 | 49 | # Check if gradient exists and is zero for masked samples 50 | for i, grad in enumerate(inputs.grad): 51 | if i == mask_idx: 52 | self.assertTrue(torch.all(grad == 0).item()) 53 | else: 54 | self.assertTrue(not torch.all(grad == 0)) 55 | 56 | def test_all_parameters_updated(self): 57 | optim = torch.optim.SGD(self.net.parameters(), lr=0.1) 58 | 59 | outputs = self.net(self.test_inputs) 60 | loss = outputs.mean() 61 | loss.backward() 62 | optim.step() 63 | 64 | for param_name, param in self.net.named_parameters(): 65 | if param.requires_grad: 66 | with self.subTest(name=param_name): 67 | self.assertIsNotNone(param.grad) 68 | self.assertNotEqual(0., torch.sum(param.grad ** 2)) 69 | 70 | 71 | class DatasetTestsMixin: 72 | def test_shape(self): 73 | with self.subTest(split='train'): 74 | self._check_shape(self.data.train_data) 75 | with self.subTest(split='test'): 76 | self._check_shape(self.data.test_data) 77 | 78 | def _check_shape(self, dataset): 79 | sample, _ = dataset[0] 80 | self.assertEqual(self.data_shape, sample.shape) 81 | 82 | def test_scaling(self): 83 | with self.subTest(split='train'): 84 | self._check_scaling(self.data.train_data) 85 | with self.subTest(split='test'): 86 | self._check_scaling(self.data.test_data) 87 | 88 | def _check_scaling(self, data): 89 | for sample, _ in data: 90 | # Values are in range [-1, 1] 91 | self.assertGreaterEqual(1, sample.max()) 92 | self.assertLessEqual(-1, sample.min()) 93 | # Values are not only covering [0, 1] or [-1, 0] 94 | self.assertTrue(torch.any(sample < 0)) 95 | self.assertTrue(torch.any(sample > 0)) 96 | 97 | def test_augmentation(self): 98 | with self.subTest(split='train'): 99 | self._check_augmentation(self.data.train_data, active=True) 100 | with self.subTest(split='test'): 101 | self._check_augmentation(self.data.test_data, active=False) 102 | 103 | def _check_augmentation(self, data, active): 104 | are_same = [] 105 | for i in range(len(data)): 106 | sample_1, _ = data[i] 107 | sample_2, _ = data[i] 108 | are_same.append(0 == torch.sum(sample_1 - sample_2)) 109 | 110 | if active: 111 | self.assertTrue(not all(are_same)) 112 | else: 113 | self.assertTrue(all(are_same)) 114 | 115 | def test_single_process_dataloader(self): 116 | with self.subTest(split='train'): 117 | self._check_dataloader(self.data.train_data, num_workers=0) 118 | with self.subTest(split='test'): 119 | self._check_dataloader(self.data.test_data, num_workers=0) 120 | 121 | def test_multi_process_dataloader(self): 122 | with self.subTest(split='train'): 123 | self._check_dataloader(self.data.train_data, num_workers=2) 124 | with self.subTest(split='test'): 125 | self._check_dataloader(self.data.test_data, num_workers=2) 126 | 127 | def _check_dataloader(self, data, num_workers): 128 | loader = DataLoader(data, batch_size=4, num_workers=num_workers) 129 | for _ in loader: 130 | pass 131 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | import dataset 6 | from tests import templates 7 | 8 | class TestMNIST(unittest.TestCase, templates.DatasetTestsMixin): 9 | def setUp(self): 10 | self.data = dataset.MyMNIST() 11 | self.data_shape = torch.Size((1, 32, 32)) 12 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | import model 6 | from tests import templates 7 | 8 | 9 | class TestCNNVAE(unittest.TestCase, templates.ModelTestsMixin): 10 | def setUp(self): 11 | self.test_inputs = torch.randn(4, 1, 32, 32) 12 | self.net = model.CNNVAE(input_shape=(1, 32, 32), bottleneck_dim=16) 13 | 14 | 15 | class TestMLPVAE(unittest.TestCase, templates.ModelTestsMixin): 16 | def setUp(self): 17 | self.test_inputs = torch.randn(4, 1, 32, 32) 18 | self.net = model.MLPVAE(input_shape=(1, 32, 32), bottleneck_dim=16) 19 | -------------------------------------------------------------------------------- /tests/test_trainer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tempfile 3 | import shutil 4 | from unittest import mock 5 | 6 | import numpy as np 7 | import scipy.stats 8 | import torch 9 | from torch.utils.data import Subset 10 | 11 | import model 12 | import dataset 13 | import trainer 14 | 15 | 16 | class TestTrainer(unittest.TestCase): 17 | def setUp(self): 18 | seed = 42 19 | torch.manual_seed(seed) 20 | np.random.seed(seed) 21 | 22 | # Build dataset with only one batch 23 | self.data = dataset.MyMNIST() 24 | self.data.train_data = Subset(self.data.train_data, range(4)) 25 | self.data.test_data = Subset(self.data.train_data, range(4)) 26 | vae = model.CNNVAE(self.data.train_data[0][0].shape, bottleneck_dim=10) 27 | optim = torch.optim.Adam(vae.parameters()) 28 | self.log_dir = tempfile.mkdtemp() 29 | self.vae_trainer = trainer.Trainer(vae, self.data, optim, 30 | batch_size=4, 31 | device='cpu', 32 | log_dir=self.log_dir, 33 | num_generated_images=1) 34 | 35 | def tearDown(self): 36 | shutil.rmtree(self.log_dir) 37 | 38 | @torch.no_grad() 39 | def test_kl_divergence(self): 40 | mu = np.random.randn(10) * 0.25 41 | sigma = np.random.randn(10) * 0.1 + 1. 42 | standard_normal_samples = np.random.randn(100000, 10) 43 | transformed_normal_sample = standard_normal_samples * sigma + mu 44 | 45 | # Calculate empirical pdfs for both distributions 46 | bins = 1000 47 | bin_range = [-2, 2] 48 | expected_kl_div = 0 49 | for i in range(10): 50 | standard_normal_dist, _ = np.histogram(standard_normal_samples[:, i], bins, bin_range) 51 | transformed_normal_dist, _ = np.histogram(transformed_normal_sample[:, i], bins, bin_range) 52 | expected_kl_div += scipy.stats.entropy(transformed_normal_dist, standard_normal_dist) 53 | 54 | actual_kl_div = self.vae_trainer._kl_divergence(torch.tensor(sigma).log(), torch.tensor(mu)) 55 | 56 | self.assertAlmostEqual(expected_kl_div, actual_kl_div.numpy(), delta=0.05) 57 | 58 | def test_overfit_on_one_batch(self): 59 | # Overfit on single batch 60 | self.vae_trainer.train(500) 61 | 62 | # Overfitting a VAE is hard, so we do not choose 0. as a goal 63 | # 30 sum of squared errors would be a deviation of ~0,04 per pixel given a really small KL-Div 64 | self.assertGreaterEqual(30, self.vae_trainer.eval()) 65 | 66 | def test_logging(self): 67 | with mock.patch.object(self.vae_trainer.summary, 'add_scalar') as add_scalar_mock: 68 | self.vae_trainer.train(1) 69 | 70 | expected_calls = [mock.call('train/recon_loss', mock.ANY, 0), 71 | mock.call('train/kl_div_loss', mock.ANY, 0), 72 | mock.call('train/loss', mock.ANY, 0), 73 | mock.call('test/loss', mock.ANY, 0)] 74 | add_scalar_mock.assert_has_calls(expected_calls) 75 | --------------------------------------------------------------------------------