├── .gitignore
├── .vscode
│   └── settings.json
├── Interpolations
│   └── linear_interpolation_dims=128.png
├── LICENSE.md
├── README.md
├── Reconstructions
│   ├── recons_epoch_10.png
│   ├── recons_epoch_20_128dims.png
│   └── recons_epoch_20_256dims.png
├── Samples
│   ├── generated_samples_epoch_10.png
│   ├── generated_samples_epoch_20_128dims.png
│   └── generated_samples_epoch_20_256dims.png
├── base
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
├── config.json
├── data_loader
│   └── data_loaders.py
├── logger
│   ├── __init__.py
│   ├── logger.py
│   ├── logger_config.json
│   └── visualization.py
├── model
│   ├── loss.py
│   ├── metric.py
│   ├── model.py
│   └── types_.py
├── parse_config.py
├── requirements.txt
├── test.py
├── train.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
├── traversal_test.ipynb
└── utils
    ├── __init__.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # input data, saved log, checkpoints
104 | data/
105 | input/
106 | saved/
107 | datasets/
108 |
109 | # editor, os cache directory
110 | .vscode/
111 | .idea/
112 | __MACOSX/
113 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/Interpolations/linear_interpolation_dims=128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Interpolations/linear_interpolation_dims=128.png
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Julian Chan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Image Generation and Reconstruction with Convolutional Variational Autoencoder (VAE) in PyTorch
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Implementation Details
16 |
17 | A PyTorch implementation of the standard Variational Autoencoder (VAE). The amortized inference model (encoder) is parameterized by a convolutional network, while the generative model (decoder) is parameterized by a transposed convolutional network. The approximate posterior is a fully factorized Gaussian, i.e. a Gaussian distribution with diagonal covariance.
18 |
19 | This implementation supports model training on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Since this project is a proof of concept, the original 178 x 218 images are resized and center-cropped to 64 x 64 to speed up training. For ease of access, the zip file containing the dataset can be downloaded from: https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip.
20 |
21 | The VAE model was evaluated on several downstream tasks, such as image reconstruction and image generation. Some sample results can be found in the [Results](https://github.com/julian-8897/Vanilla-VAE-PyTorch/blob/master/README.md#--Results) section.
22 |
23 |
24 | ![VAE diagram](https://learnopencv.com/wp-content/uploads/2020/11/vae-diagram-1-2048x1126.jpg)
25 | *Figure 1: Visual Representation of VAE. Image source: [LearnOpenCV](https://learnopencv.com/wp-content/uploads/2020/11/vae-diagram-1-2048x1126.jpg)*
26 |
27 |
28 | ## Requirements
29 |
30 | - Python >= 3.9
31 | - PyTorch >= 1.9
32 |
33 | ## Installation Guide
34 |
35 | ```
36 | $ git clone https://github.com/julian-8897/Conv-VAE-PyTorch.git
37 | $ cd Conv-VAE-PyTorch
38 | $ pip install -r requirements.txt
39 | ```
40 |
41 | ## Usage
42 |
43 | ### Training
44 |
45 | To train the model, edit the `config.json` configuration file as needed and run:
46 |
47 | ```
48 | python train.py --config config.json
49 | ```
50 |
51 | ### Resuming Training
52 |
53 | To resume training of the model from a checkpoint, you can run the following command:
54 |
55 | ```
56 | python train.py --resume path/to/checkpoint
57 | ```
58 |
59 | ### Testing
60 |
61 | To test the model, you can run the following command:
62 |
63 | ```
64 | python test.py --resume path/to/checkpoint
65 | ```
66 |
67 | Generated plots are stored in the `Reconstructions` and `Samples` folders.
68 |
69 | ---
70 |
71 |
72 | ## Results
73 |
74 |
75 | ## 128 Latent Dimensions
76 |
77 | | Reconstructed Samples | Generated Samples |
78 | | --------------------- | ----------------- |
79 | | ![][1] | ![][2] |
80 |
81 | ## 256 Latent Dimensions
82 |
83 | | Reconstructed Samples | Generated Samples |
84 | | --------------------- | ----------------- |
85 | | ![][3] | ![][4] |
86 |
87 | [1]: https://github.com/julian-8897/Vanilla-VAE-PyTorch/blob/master/Reconstructions/recons_epoch_20_128dims.png
88 | [2]: https://github.com/julian-8897/Vanilla-VAE-PyTorch/blob/master/Samples/generated_samples_epoch_20_128dims.png
89 | [3]: https://github.com/julian-8897/Vanilla-VAE-PyTorch/blob/master/Reconstructions/recons_epoch_20_256dims.png
90 | [4]: https://github.com/julian-8897/Vanilla-VAE-PyTorch/blob/master/Samples/generated_samples_epoch_20_256dims.png
91 |
92 | ## References
93 |
94 | 1. Original VAE paper "Auto-Encoding Variational Bayes" by Kingma & Welling:
95 | https://arxiv.org/abs/1312.6114
96 |
97 | 2. Various implementations of VAEs in PyTorch:
98 | https://github.com/AntixK/PyTorch-VAE
99 |
100 | 3. PyTorch template used in this project:
101 | https://github.com/victoresque/pytorch-template
102 |
103 | 4. A comprehensive introduction to VAEs:
104 | https://arxiv.org/pdf/1906.02691.pdf
105 |
--------------------------------------------------------------------------------
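The README states that images are cropped to 64 x 64 before passing through the convolutional encoder and the transposed-convolutional decoder. One way to see why `model/model.py` sizes `fc_mu` as `hidden_dims[-1]*4` is to trace the spatial size through each layer using the output-size formulas from the PyTorch `Conv2d` and `ConvTranspose2d` documentation. The plain-Python sketch below does that. The helper names are hypothetical, and it assumes the defaults used by `VanillaVAE`: `hidden_dims = [32, 64, 128, 256, 512]`, `kernel_size=3, stride=2, padding=1`, and `output_padding=1` in the decoder.

```python
def conv2d_out(size, kernel=3, stride=2, padding=1, dilation=1):
    # Output-size formula from the torch.nn.Conv2d documentation (one spatial axis)
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose2d_out(size, kernel=3, stride=2, padding=1, output_padding=1, dilation=1):
    # Output-size formula from the torch.nn.ConvTranspose2d documentation (one spatial axis)
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def trace_shapes(image_size=64, hidden_dims=(32, 64, 128, 256, 512)):
    """Return encoder sizes, flattened feature width, and decoder sizes."""
    enc = [image_size]
    for _ in hidden_dims:  # one stride-2 Conv2d per hidden dim
        enc.append(conv2d_out(enc[-1]))
    flat = hidden_dims[-1] * enc[-1] * enc[-1]  # channels * H * W after the encoder
    dec = [enc[-1]]
    # len(hidden_dims) - 1 decoder blocks plus the ConvTranspose2d in final_layer
    for _ in range(len(hidden_dims)):
        dec.append(conv_transpose2d_out(dec[-1]))
    return enc, flat, dec
```

With these settings the encoder ends at 2 x 2, so the flattened vector has 512 * 2 * 2 = 2048 = `hidden_dims[-1]*4` entries. The final `Conv2d` in `final_layer` (kernel 3, padding 1, stride 1) keeps the decoder at 64 x 64. Any other input size would need a different `Linear` input width.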
/Reconstructions/recons_epoch_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Reconstructions/recons_epoch_10.png
--------------------------------------------------------------------------------
/Reconstructions/recons_epoch_20_128dims.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Reconstructions/recons_epoch_20_128dims.png
--------------------------------------------------------------------------------
/Reconstructions/recons_epoch_20_256dims.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Reconstructions/recons_epoch_20_256dims.png
--------------------------------------------------------------------------------
/Samples/generated_samples_epoch_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Samples/generated_samples_epoch_10.png
--------------------------------------------------------------------------------
/Samples/generated_samples_epoch_20_128dims.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Samples/generated_samples_epoch_20_128dims.png
--------------------------------------------------------------------------------
/Samples/generated_samples_epoch_20_256dims.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/julian-8897/Conv-VAE-PyTorch/d86013578435b9208468a482cb14e7a3c79d7510/Samples/generated_samples_epoch_20_256dims.png
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 |
6 |
7 | class BaseDataLoader(DataLoader):
8 |     """
9 |     Base class for all data loaders
10 |     """
11 |     def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
12 |         self.validation_split = validation_split
13 |         self.shuffle = shuffle
14 |
15 |         self.batch_idx = 0
16 |         self.n_samples = len(dataset)
17 |
18 |         self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
19 |
20 |         self.init_kwargs = {
21 |             'dataset': dataset,
22 |             'batch_size': batch_size,
23 |             'shuffle': self.shuffle,
24 |             'collate_fn': collate_fn,
25 |             'num_workers': num_workers
26 |         }
27 |         super().__init__(sampler=self.sampler, **self.init_kwargs)
28 |
29 |     def _split_sampler(self, split):
30 |         if split == 0.0:
31 |             return None, None
32 |
33 |         idx_full = np.arange(self.n_samples)
34 |
35 |         np.random.seed(0)
36 |         np.random.shuffle(idx_full)
37 |
38 |         if isinstance(split, int):
39 |             assert split > 0
40 |             assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
41 |             len_valid = split
42 |         else:
43 |             len_valid = int(self.n_samples * split)
44 |
45 |         valid_idx = idx_full[0:len_valid]
46 |         train_idx = np.delete(idx_full, np.arange(0, len_valid))
47 |
48 |         train_sampler = SubsetRandomSampler(train_idx)
49 |         valid_sampler = SubsetRandomSampler(valid_idx)
50 |
51 |         # turn off shuffle option which is mutually exclusive with sampler
52 |         self.shuffle = False
53 |         self.n_samples = len(train_idx)
54 |
55 |         return train_sampler, valid_sampler
56 |
57 |     def split_validation(self):
58 |         if self.valid_sampler is None:
59 |             return None
60 |         else:
61 |             return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
62 |
--------------------------------------------------------------------------------
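`_split_sampler` shuffles all indices once with a fixed seed. It then takes the first `len_valid` indices as the validation set, where an `int` split is an absolute count and a `float` is a fraction, and wraps both index lists in `SubsetRandomSampler`. The sketch below mirrors that logic with the standard library. `split_indices` is a hypothetical name. Because `random.Random(0)` produces a different permutation from `np.random.seed(0)`, only the sizes and the disjointness match the original, not the exact indices.

```python
import random


def split_indices(n_samples, split, seed=0):
    """Return (train_idx, valid_idx) following BaseDataLoader._split_sampler's rules."""
    if split == 0.0:
        # No validation set: the whole dataset is used for training
        return list(range(n_samples)), []
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)  # stdlib stand-in for np.random.seed(0) + np.random.shuffle
    if isinstance(split, int):
        assert 0 < split < n_samples, "validation set size is larger than the dataset"
        len_valid = split
    else:
        len_valid = int(n_samples * split)  # fraction of the dataset, rounded down
    return idx[len_valid:], idx[:len_valid]
```

With `config.json`'s `validation_split` of 0.2, `int(n_samples * 0.2)` images go to validation and the rest to training. The fixed seed makes the split identical across runs.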
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 |     """
8 |     Base class for all models
9 |     """
10 |     @abstractmethod
11 |     def forward(self, *inputs):
12 |         """
13 |         Forward pass logic
14 |
15 |         :return: Model output
16 |         """
17 |         raise NotImplementedError
18 |
19 |     def __str__(self):
20 |         """
21 |         Return the default module summary followed by the number of trainable parameters
22 |         """
23 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
24 |         params = sum([np.prod(p.size()) for p in model_parameters])
25 |         return super().__str__() + '\nTrainable parameters: {}'.format(params)
26 |
--------------------------------------------------------------------------------
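`BaseModel.__str__` counts only tensors returned by `parameters()` that have `requires_grad=True`. For `VanillaVAE` this includes each `Conv2d` weight and bias and the `BatchNorm2d` affine weight and bias. It excludes BatchNorm's running mean and variance, which are buffers rather than parameters. The plain-Python helpers below (hypothetical names) reproduce those counts by hand for the default encoder in `model/model.py` (`hidden_dims = [32, 64, 128, 256, 512]`, 3 input channels, `kernel_size=3`) and for `fc_mu` with the 128 latent dimensions set in `config.json`.

```python
def conv_bn_params(in_channels, out_channels, kernel_size=3):
    """Trainable parameters of Conv2d(bias=True) followed by BatchNorm2d(affine=True)."""
    conv = out_channels * in_channels * kernel_size * kernel_size + out_channels  # weight + bias
    bn = 2 * out_channels  # BatchNorm weight and bias; running stats are buffers, not counted
    return conv + bn


def linear_params(in_features, out_features):
    """Trainable parameters of nn.Linear(bias=True)."""
    return in_features * out_features + out_features


def encoder_params(in_channels=3, hidden_dims=(32, 64, 128, 256, 512)):
    """Sum over the encoder's Conv2d + BatchNorm2d blocks (LeakyReLU has no parameters)."""
    total = 0
    for h_dim in hidden_dims:
        total += conv_bn_params(in_channels, h_dim)
        in_channels = h_dim
    return total
```

These are the counts `__str__` would report for those submodules alone. The full model also includes `fc_var`, `decoder_input`, the decoder, `final_layer` and, when enabled, the flow parameters.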
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import abstractmethod
3 | from numpy import inf
4 | from logger import TensorboardWriter
5 |
6 |
7 | class BaseTrainer:
8 |     """
9 |     Base class for all trainers
10 |     """
11 |
12 |     def __init__(self, model, criterion, optimizer, config):
13 |         self.config = config
14 |         self.logger = config.get_logger(
15 |             'trainer', config['trainer']['verbosity'])
16 |
17 |         self.model = model
18 |         self.criterion = criterion
19 |         # self.metric_ftns = metric_ftns
20 |         self.optimizer = optimizer
21 |
22 |         cfg_trainer = config['trainer']
23 |         self.epochs = cfg_trainer['epochs']
24 |         self.save_period = cfg_trainer['save_period']
25 |         self.monitor = cfg_trainer.get('monitor', 'off')
26 |
27 |         # configuration to monitor model performance and save best
28 |         if self.monitor == 'off':
29 |             self.mnt_mode = 'off'
30 |             self.mnt_best = 0
31 |         else:
32 |             self.mnt_mode, self.mnt_metric = self.monitor.split()
33 |             assert self.mnt_mode in ['min', 'max']
34 |
35 |             self.mnt_best = inf if self.mnt_mode == 'min' else -inf
36 |             self.early_stop = cfg_trainer.get('early_stop', inf)
37 |             if self.early_stop <= 0:
38 |                 self.early_stop = inf
39 |
40 |         self.start_epoch = 1
41 |
42 |         self.checkpoint_dir = config.save_dir
43 |
44 |         # setup visualization writer instance
45 |         self.writer = TensorboardWriter(
46 |             config.log_dir, self.logger, cfg_trainer['tensorboard'])
47 |
48 |         if config.resume is not None:
49 |             self._resume_checkpoint(config.resume)
50 |
51 |     @abstractmethod
52 |     def _train_epoch(self, epoch):
53 |         """
54 |         Training logic for an epoch
55 |
56 |         :param epoch: Current epoch number
57 |         """
58 |         raise NotImplementedError
59 |
60 |     def train(self):
61 |         """
62 |         Full training logic
63 |         """
64 |         not_improved_count = 0
65 |         for epoch in range(self.start_epoch, self.epochs + 1):
66 |             result = self._train_epoch(epoch)
67 |
68 |             # save logged information into log dict
69 |             log = {'epoch': epoch}
70 |             log.update(result)
71 |
72 |             # print logged information to the screen
73 |             for key, value in log.items():
74 |                 self.logger.info(' {:15s}: {}'.format(str(key), value))
75 |
76 |             # evaluate model performance according to configured metric, save best checkpoint as model_best
77 |             best = False
78 |             if self.mnt_mode != 'off':
79 |                 try:
80 |                     # check whether model performance improved or not, according to specified metric(mnt_metric)
81 |                     improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
82 |                                (self.mnt_mode ==
83 |                                 'max' and log[self.mnt_metric] >= self.mnt_best)
84 |                 except KeyError:
85 |                     self.logger.warning("Warning: Metric '{}' is not found. "
86 |                                         "Model performance monitoring is disabled.".format(self.mnt_metric))
87 |                     self.mnt_mode = 'off'
88 |                     improved = False
89 |
90 |                 if improved:
91 |                     self.mnt_best = log[self.mnt_metric]
92 |                     not_improved_count = 0
93 |                     best = True
94 |                 else:
95 |                     not_improved_count += 1
96 |
97 |                 if not_improved_count > self.early_stop:
98 |                     self.logger.info("Validation performance didn't improve for {} epochs. "
99 |                                      "Training stops.".format(self.early_stop))
100 |                     break
101 |
102 |             if epoch % self.save_period == 0:
103 |                 self._save_checkpoint(epoch, save_best=best)
104 |
105 |     def _save_checkpoint(self, epoch, save_best=False):
106 |         """
107 |         Save the training state (architecture name, model and optimizer state, best monitored value, config)
108 |
109 |         :param epoch: current epoch number
110 |         :param save_best: if True, also save a copy of the state as 'model_best.pth'
111 |                           in addition to 'checkpoint-epoch{epoch}.pth'
112 |         """
113 |         arch = type(self.model).__name__
114 |         state = {
115 |             'arch': arch,
116 |             'epoch': epoch,
117 |             'state_dict': self.model.state_dict(),
118 |             'optimizer': self.optimizer.state_dict(),
119 |             'monitor_best': self.mnt_best,
120 |             'config': self.config
121 |         }
122 |         filename = str(self.checkpoint_dir /
123 |                        'checkpoint-epoch{}.pth'.format(epoch))
124 |         torch.save(state, filename)
125 |         self.logger.info("Saving checkpoint: {} ...".format(filename))
126 |         if save_best:
127 |             best_path = str(self.checkpoint_dir / 'model_best.pth')
128 |             torch.save(state, best_path)
129 |             self.logger.info("Saving current best: model_best.pth ...")
130 |
131 |     def _resume_checkpoint(self, resume_path):
132 |         """
133 |         Resume from saved checkpoints
134 |
135 |         :param resume_path: Checkpoint path to be resumed
136 |         """
137 |         resume_path = str(resume_path)
138 |         self.logger.info("Loading checkpoint: {} ...".format(resume_path))
139 |         checkpoint = torch.load(resume_path)
140 |         self.start_epoch = checkpoint['epoch'] + 1
141 |         self.mnt_best = checkpoint['monitor_best']
142 |
143 |         # load architecture params from checkpoint.
144 |         if checkpoint['config']['arch'] != self.config['arch']:
145 |             self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
146 |                                 "checkpoint. This may yield an exception while state_dict is being loaded.")
147 |         self.model.load_state_dict(checkpoint['state_dict'])
148 |
149 |         # load optimizer state from checkpoint only when optimizer type is not changed.
150 |         if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
151 |             self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
152 |                                 "Optimizer parameters not being resumed.")
153 |         else:
154 |             self.optimizer.load_state_dict(checkpoint['optimizer'])
155 |
156 |         self.logger.info(
157 |             "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
158 |
--------------------------------------------------------------------------------
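The monitoring logic in `BaseTrainer.train()` can be replayed without a model. The hypothetical `simulate_monitor` below mirrors it in plain Python. With `mode='min'`, an epoch counts as an improvement when the metric is less than or equal to the best value so far, so ties count as improvements. Training stops once the number of consecutive non-improving epochs exceeds `early_stop`. The sketch ignores checkpoint saving, which in `train()` also depends on `save_period`.

```python
import math


def simulate_monitor(values, mode="min", early_stop=math.inf):
    """Replay BaseTrainer.train()'s monitor logic on per-epoch metric values.

    Returns (best_epochs, stop_epoch); stop_epoch is None if all epochs ran.
    """
    assert mode in ("min", "max")
    if early_stop <= 0:
        early_stop = math.inf  # same fallback as BaseTrainer.__init__
    best = math.inf if mode == "min" else -math.inf
    not_improved = 0
    best_epochs = []
    for epoch, value in enumerate(values, start=1):  # epochs are 1-based, like start_epoch
        improved = value <= best if mode == "min" else value >= best
        if improved:
            best = value
            not_improved = 0
            best_epochs.append(epoch)
        else:
            not_improved += 1
        if not_improved > early_stop:
            return best_epochs, epoch
    return best_epochs, None
```

With `"monitor": "min val_loss"` and `"early_stop": 10` from `config.json`, training would therefore stop only after 11 consecutive epochs without a new best `val_loss`.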
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "VAE_CelebA",
3 | "n_gpu": 1,
4 | "arch": {
5 | "type": "VanillaVAE",
6 | "args": {
7 | "in_channels": 3,
8 | "latent_dims": 128,
9 | "flow": false
10 | }
11 | },
12 | "data_loader": {
13 | "type": "CelebDataLoader",
14 | "args": {
15 | "data_dir": "data/",
16 | "batch_size": 64,
17 | "shuffle": true,
18 | "validation_split": 0.2,
19 | "num_workers": 2
20 | }
21 | },
22 | "optimizer": {
23 | "type": "Adam",
24 | "args": {
25 | "lr": 0.005,
26 | "weight_decay": 0.0,
27 | "amsgrad": true
28 | }
29 | },
30 | "loss": "elbo_loss",
31 | "metrics": [],
32 | "lr_scheduler": {
33 | "type": "StepLR",
34 | "args": {
35 | "step_size": 50,
36 | "gamma": 0.1
37 | }
38 | },
39 | "trainer": {
40 | "epochs": 20,
41 | "save_dir": "saved/",
42 | "save_period": 1,
43 | "verbosity": 2,
44 | "monitor": "min val_loss",
45 | "early_stop": 10,
46 | "tensorboard": true
47 | }
48 | }
--------------------------------------------------------------------------------
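`config.json` pairs `"epochs": 20` with a `StepLR` scheduler that uses `"step_size": 50` and `"gamma": 0.1`. StepLR's closed form is `lr = base_lr * gamma ** (epoch // step_size)`. Assuming the trainer (`trainer/trainer.py`, not shown here) calls `lr_scheduler.step()` once per epoch, the sketch below suggests the decay would never trigger within the configured 20 epochs. `step_lr` is a hypothetical helper that evaluates this closed form.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    # StepLR closed form, assuming scheduler.step() is called once per epoch
    # and epochs are counted from 0
    return base_lr * gamma ** (epoch // step_size)


# Learning rate for each of the 20 configured epochs
schedule = [step_lr(0.005, e, step_size=50, gamma=0.1) for e in range(20)]
```

If a decay within the run is intended, `step_size` would need to be smaller than `epochs`, for example 10.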
/data_loader/data_loaders.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms
2 | from base import BaseDataLoader
3 |
4 |
5 | class MnistDataLoader(BaseDataLoader):
6 |     """
7 |     MNIST data loading demo using BaseDataLoader
8 |     """
9 |
10 |     def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
11 |         trsfm = transforms.Compose([
12 |             transforms.ToTensor(),
13 |             transforms.Normalize((0.1307,), (0.3081,))
14 |         ])
15 |         self.data_dir = data_dir
16 |         self.dataset = datasets.MNIST(
17 |             self.data_dir, train=training, download=True, transform=trsfm)
18 |         super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
19 |
20 |
21 | class CelebDataLoader(BaseDataLoader):
22 |     """
23 |     CelebA data loading
24 |     Download and extract into a subdirectory of data_dir (ImageFolder expects data_dir/<folder>/<image>):
25 |     https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip
26 |     """
27 |
28 |     def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, image_size=64):
29 |         transform = transforms.Compose([
30 |             transforms.Resize(image_size),
31 |             transforms.CenterCrop(image_size),
32 |             transforms.ToTensor(),
33 |         ])
34 |         self.data_dir = data_dir
35 |         self.dataset = datasets.ImageFolder(
36 |             self.data_dir, transform=transform)
37 |         super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
38 |
--------------------------------------------------------------------------------
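`CelebDataLoader` applies `Resize(image_size)` followed by `CenterCrop(image_size)`. For an integer size, torchvision's `Resize` scales the shorter edge to that size and preserves the aspect ratio. A 178 x 218 CelebA image therefore becomes 64 x 78, and the crop then removes 7 rows from the top and 7 from the bottom. The plain-Python helpers below reproduce that arithmetic. They are hypothetical and follow our reading of torchvision's rounding: truncation for the long edge and `round` for the crop offset.

```python
def resize_shorter_side(width, height, size):
    """(width, height) after Resize(size) with an int size: the shorter edge becomes `size`."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


w, h = resize_shorter_side(178, 218, 64)  # CelebA aligned images are 178 wide, 218 tall
box = center_crop_box(w, h, 64)
```

The crop keeps the full width and trims only the vertical margins, which mostly removes background above and below the face.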
/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .visualization import *
--------------------------------------------------------------------------------
/logger/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 | from utils import read_json
5 |
6 |
7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
8 |     """
9 |     Setup logging configuration
10 |     """
11 |     log_config = Path(log_config)
12 |     if log_config.is_file():
13 |         config = read_json(log_config)
14 |         # modify logging paths based on run config
15 |         for _, handler in config['handlers'].items():
16 |             if 'filename' in handler:
17 |                 handler['filename'] = str(save_dir / handler['filename'])
18 |
19 |         logging.config.dictConfig(config)
20 |     else:
21 |         print("Warning: logging configuration file is not found in {}.".format(log_config))
22 |         logging.basicConfig(level=default_level)
23 |
--------------------------------------------------------------------------------
/logger/logger_config.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "version": 1,
4 | "disable_existing_loggers": false,
5 | "formatters": {
6 | "simple": {"format": "%(message)s"},
7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
8 | },
9 | "handlers": {
10 | "console": {
11 | "class": "logging.StreamHandler",
12 | "level": "DEBUG",
13 | "formatter": "simple",
14 | "stream": "ext://sys.stdout"
15 | },
16 | "info_file_handler": {
17 | "class": "logging.handlers.RotatingFileHandler",
18 | "level": "INFO",
19 | "formatter": "datetime",
20 | "filename": "info.log",
21 | "maxBytes": 10485760,
22 | "backupCount": 20, "encoding": "utf8"
23 | }
24 | },
25 | "root": {
26 | "level": "INFO",
27 | "handlers": [
28 | "console",
29 | "info_file_handler"
30 | ]
31 | }
32 | }
--------------------------------------------------------------------------------
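`setup_logging` loads `logger_config.json` and prefixes every handler `filename` with the run's `save_dir`, so each run writes its own `info.log`. The standard-library sketch below applies the same rewrite to a cut-down version of the config that keeps only the rotating file handler. `setup_logging_sketch` and `demo` are hypothetical helpers, not part of the repository.

```python
import logging
import logging.config
import tempfile
from pathlib import Path


def setup_logging_sketch(save_dir):
    """Apply a reduced logger_config.json, placing log files under save_dir."""
    config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"simple": {"format": "%(message)s"}},
        "handlers": {
            "info_file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "level": "INFO",
                "formatter": "simple",
                "filename": "info.log",
                "maxBytes": 10485760,
                "backupCount": 20,
                "encoding": "utf8",
            }
        },
        "root": {"level": "INFO", "handlers": ["info_file_handler"]},
    }
    # Same rewrite as setup_logging(): a relative filename becomes save_dir / filename
    for handler in config["handlers"].values():
        if "filename" in handler:
            handler["filename"] = str(Path(save_dir) / handler["filename"])
    logging.config.dictConfig(config)


def demo(message):
    """Log one message into a fresh temporary run directory and return the file's contents."""
    save_dir = tempfile.mkdtemp()
    setup_logging_sketch(save_dir)
    logging.getLogger("trainer").info(message)  # propagates to the root handler
    for handler in logging.getLogger().handlers:
        handler.flush()
    return (Path(save_dir) / "info.log").read_text(encoding="utf8").strip()
```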
/logger/visualization.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from datetime import datetime
3 |
4 |
5 | class TensorboardWriter():
6 |     def __init__(self, log_dir, logger, enabled):
7 |         self.writer = None
8 |         self.selected_module = ""
9 |
10 |         if enabled:
11 |             log_dir = str(log_dir)
12 |
13 |             # Retrieve visualization writer.
14 |             succeeded = False
15 |             for module in ["torch.utils.tensorboard", "tensorboardX"]:
16 |                 try:
17 |                     self.writer = importlib.import_module(module).SummaryWriter(log_dir)
18 |                     succeeded = True
19 |                     self.selected_module = module  # record the backend that was actually loaded
20 |                     break
21 |                 except ImportError:
22 |                     succeeded = False
23 |
24 |             if not succeeded:
25 |                 message = "Warning: visualization (Tensorboard) is enabled, but no Tensorboard backend is installed on " \
26 |                     "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
27 |                     "version >= 1.1 to use 'torch.utils.tensorboard', or turn off the option in the 'config.json' file."
28 |                 logger.warning(message)
29 |
30 |         self.step = 0
31 |         self.mode = ''
32 |
33 |         self.tb_writer_ftns = {
34 |             'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
35 |             'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
36 |         }
37 |         self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
38 |         self.timer = datetime.now()
39 |
40 |     def set_step(self, step, mode='train'):
41 |         self.mode = mode
42 |         self.step = step
43 |         if step == 0:
44 |             self.timer = datetime.now()
45 |         else:
46 |             duration = datetime.now() - self.timer
47 |             self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
48 |             self.timer = datetime.now()
49 |
50 |     def __getattr__(self, name):
51 |         """
52 |         If `name` is one of the supported add_*() methods:
53 |             return a wrapper that adds the step and the mode tag (train/valid) before calling the writer;
54 |             the wrapper does nothing when visualization is disabled or unavailable.
55 |         Otherwise: raise AttributeError.
56 |         """
57 |         if name in self.tb_writer_ftns:
58 |             add_data = getattr(self.writer, name, None)
59 |
60 |             def wrapper(tag, data, *args, **kwargs):
61 |                 if add_data is not None:
62 |                     # add mode(train/valid) tag
63 |                     if name not in self.tag_mode_exceptions:
64 |                         tag = '{}/{}'.format(tag, self.mode)
65 |                     add_data(tag, data, self.step, *args, **kwargs)
66 |             return wrapper
67 |         else:
68 |             # __getattr__ only runs after normal lookup failed; re-raise with a clearer message
69 |             try:
70 |                 attr = object.__getattribute__(self, name)
71 |             except AttributeError:
72 |                 raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
73 |             return attr
74 |
--------------------------------------------------------------------------------
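Python calls `TensorboardWriter.__getattr__` only when normal attribute lookup fails. For whitelisted names, it returns a wrapper that appends the current mode to the tag (except for histograms and embeddings), injects the current step, and silently does nothing when no Tensorboard backend was loaded. The self-contained sketch below reproduces that delegation pattern. A hypothetical `RecordingWriter` stands in for `SummaryWriter`, so the sketch runs without Tensorboard installed, and `TaggingWriter` is likewise an illustrative name.

```python
class RecordingWriter:
    """Hypothetical stand-in for SummaryWriter that only records calls."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, value, step):
        self.calls.append(("add_scalar", tag, value, step))

    def add_histogram(self, tag, values, step):
        self.calls.append(("add_histogram", tag, values, step))


class TaggingWriter:
    """Minimal version of TensorboardWriter's __getattr__ delegation."""

    forwarded = {"add_scalar", "add_histogram"}
    no_mode_tag = {"add_histogram"}

    def __init__(self, writer=None):
        self.writer = writer
        self.step = 0
        self.mode = ""

    def set_step(self, step, mode="train"):
        self.step, self.mode = step, mode

    def __getattr__(self, name):
        # Reached only when normal lookup fails (e.g. add_scalar is not defined on this class)
        if name not in TaggingWriter.forwarded:
            raise AttributeError(name)
        add_data = getattr(self.writer, name, None)

        def wrapper(tag, data, *args, **kwargs):
            if add_data is None:
                return  # visualization disabled: do nothing
            if name not in TaggingWriter.no_mode_tag:
                tag = f"{tag}/{self.mode}"
            add_data(tag, data, self.step, *args, **kwargs)

        return wrapper
```

The design lets training code call `writer.add_scalar(...)` unconditionally, whether or not Tensorboard is enabled.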
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 | import torch
4 |
5 |
6 | def elbo_loss(recon_x, x, mu, logvar, beta=1):
7 |     """
8 |     Negative ELBO for a diagonal Gaussian posterior, summed over the batch
9 |     (reconstruction term + beta * KL regularization term)
10 |     """
11 |     reconstruction_function = nn.MSELoss(reduction='sum')
12 |     MSE = reconstruction_function(recon_x, x)
13 |
14 |     # https://arxiv.org/abs/1312.6114 (Appendix B)
15 |     # -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
16 |
17 |     KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
18 |     KLD = torch.sum(KLD_element).mul_(-0.5)
19 |
20 |     return MSE + beta*KLD
21 |
22 |
23 | def elbo_loss_flow(recon_x, x, mu, logvar, log_det):
24 |     """
25 |     Negative ELBO with a normalizing-flow correction, summed over the batch
26 |     (reconstruction term + KL regularization term - sum of log-det-Jacobians)
27 |     """
28 |     reconstruction_function = nn.MSELoss(reduction='sum')
29 |     MSE = reconstruction_function(recon_x, x)
30 |
31 |     # https://arxiv.org/abs/1312.6114 (Appendix B)
32 |     # -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
33 |
34 |     KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
35 |     KLD = torch.sum(KLD_element).mul_(-0.5)
36 |
37 |     return MSE + KLD - log_det.sum()  # log_det is (B x 1); sum it to match the batch-summed MSE and KLD
38 |
--------------------------------------------------------------------------------
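`elbo_loss` computes the KL term in closed form, KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), for a diagonal Gaussian posterior against a standard normal prior. The standard-library sketch below checks that formula against a Monte Carlo estimate. The estimate draws z with the reparameterization trick z = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, 1). The helper names are hypothetical. `VanillaVAE.reparameterize` is not shown in this excerpt, so this illustrates the technique rather than copying the repository's code.

```python
import math
import random


def kl_closed_form(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) for a diagonal Gaussian, as in elbo_loss."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def reparameterize(mu, logvar, rng):
    """Draw z = mu + sigma * eps with sigma = exp(0.5 * logvar) and eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def log_normal(z, mu, logvar):
    """Log-density of a diagonal Gaussian at z, summed over dimensions."""
    return sum(-0.5 * (math.log(2.0 * math.pi) + lv + (zi - m) ** 2 / math.exp(lv))
               for zi, m, lv in zip(z, mu, logvar))


def kl_monte_carlo(mu, logvar, n=50_000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] from reparameterized samples."""
    rng = random.Random(seed)
    zeros = [0.0] * len(mu)  # prior N(0, I): mean 0, log-variance 0
    total = 0.0
    for _ in range(n):
        z = reparameterize(mu, logvar, rng)
        total += log_normal(z, mu, logvar) - log_normal(z, zeros, zeros)
    return total / n
```

The closed form avoids sampling noise in the regularization term. Only the reconstruction term in the actual model depends on the sampled z.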
/model/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def metric():
5 |     """
6 |     Implement evaluation metric here if needed
7 |     """
8 |     pass
9 |
--------------------------------------------------------------------------------
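`PlanarFlow` in `model/model.py` applies f(z) = z + u_hat * tanh(w^T z + b). Here u_hat is `u` adjusted so that w^T u_hat = softplus(w^T u) - 1 > -1, which keeps the map invertible. The module returns log|1 + u_hat^T psi(z)|, with psi(z) = (1 - tanh^2(w^T z + b)) * w, as the log-determinant that `elbo_loss_flow` subtracts. The plain-Python sketch below re-implements one step in two dimensions (hypothetical helper names) and compares the analytic log-determinant with one computed from a finite-difference Jacobian.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def constrain_u(u, w):
    """u_hat = u + (m(w.u) - w.u) * w / ||w||^2 with m(x) = softplus(x) - 1."""
    wu = dot(w, u)
    m = math.log1p(math.exp(wu)) - 1.0  # softplus; fine for moderate w.u
    w_norm_sq = dot(w, w)
    return [ui + (m - wu) * wi / w_norm_sq for ui, wi in zip(u, w)]


def planar_step(z, u, w, b):
    """One planar-flow step: returns (f(z), analytic log|det J|)."""
    u_hat = constrain_u(u, w)
    h = math.tanh(dot(w, z) + b)
    f = [zi + uh * h for zi, uh in zip(z, u_hat)]
    psi = [(1.0 - h * h) * wi for wi in w]  # h'(w.z + b) * w
    return f, math.log(abs(1.0 + dot(u_hat, psi)))


def numeric_log_det_2d(z, u, w, b, eps=1e-6):
    """log|det J| from a central-difference Jacobian (2-D only)."""
    cols = []  # cols[j][i] = d f_i / d z_j
    for j in range(2):
        zp, zm = list(z), list(z)
        zp[j] += eps
        zm[j] -= eps
        fp, _ = planar_step(zp, u, w, b)
        fm, _ = planar_step(zm, u, w, b)
        cols.append([(p - q) / (2 * eps) for p, q in zip(fp, fm)])
    det = cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]
    return math.log(abs(det))
```

The closed-form determinant comes from the matrix determinant lemma, det(I + u_hat psi^T) = 1 + u_hat^T psi. This costs O(D) per step instead of the O(D^3) of a general determinant.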
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from base import BaseModel
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class PlanarFlow(nn.Module):
9 | def __init__(self, dim):
10 | """Instantiates one step of planar flow.
11 | Args:
12 | dim: input dimensionality.
13 | """
14 | super(PlanarFlow, self).__init__()
15 |
16 | self.u = nn.Parameter(torch.randn(1, dim))
17 | self.w = nn.Parameter(torch.randn(1, dim))
18 | self.b = nn.Parameter(torch.randn(1))
19 |
20 | def forward(self, x):
21 | """Forward pass.
22 | Args:
23 | x: input tensor (B x D).
24 | Returns:
25 | transformed x and log-determinant of Jacobian.
26 | """
27 | def m(x):
28 | return F.softplus(x) - 1.
29 |
30 | def h(x):
31 | return torch.tanh(x)
32 |
33 | def h_prime(x):
34 | return 1. - h(x)**2
35 |
36 | inner = (self.w * self.u).sum()
37 | u = self.u + (m(inner) - inner) * self.w / self.w.norm()**2
38 | activation = (self.w * x).sum(dim=1, keepdim=True) + self.b
39 | x = x + u * h(activation)
40 | psi = h_prime(activation) * self.w
41 | log_det = torch.log(torch.abs(1. + (u * psi).sum(dim=1, keepdim=True)))
42 |
43 | return x, log_det
44 |
45 |
46 | class Flow(nn.Module):
47 | def __init__(self, dim, type, length):
48 | """Instantiates a chain of flows.
49 | Args:
50 | dim: input dimensionality.
51 | type: type of flow.
52 | length: length of flow.
53 | """
54 | super(Flow, self).__init__()
55 |
56 | if type == 'planar':
57 | self.flow = nn.ModuleList([PlanarFlow(dim) for _ in range(length)])
58 | # elif type == 'radial':
59 | # self.flow = nn.ModuleList([RadialFlow(dim) for _ in range(length)])
60 | # elif type == 'householder':
61 | # self.flow = nn.ModuleList([HouseholderFlow(dim) for _ in range(length)])
62 | # elif type == 'nice':
63 | # self.flow = nn.ModuleList([NiceFlow(dim, i//2, i==(length-1)) for i in range(length)])
64 | else:
65 | self.flow = nn.ModuleList([])
66 |
67 | def forward(self, x):
68 | """Forward pass.
69 | Args:
70 | x: input tensor (B x D).
71 | Returns:
72 | transformed x and log-determinant of Jacobian.
73 | """
74 | [B, _] = list(x.size())
75 | # log_det = torch.zeros(B, 1).cuda()
76 | log_det = torch.zeros(B, 1)
77 | for i in range(len(self.flow)):
78 | x, inc = self.flow[i](x)
79 | log_det = log_det + inc
80 |
81 | return x, log_det
82 |
83 |
84 | class VanillaVAE(BaseModel):
85 |
86 | def __init__(self,
87 | in_channels: int,
88 | latent_dims: int,
89 | hidden_dims: List[int] = None,
90 | flow_check=False,
91 | **kwargs) -> None:
92 | """Instantiates the VAE model
93 |
94 | Params:
95 | in_channels (int): Number of input channels
96 | latent_dims (int): Size of latent dimensions
97 | hidden_dims (List[int]): List of hidden dimensions
98 | """
99 | super(VanillaVAE, self).__init__()
100 | self.latent_dim = latent_dims
101 | self.flow_check = flow_check
102 |
103 | if self.flow_check:
104 | self.flow = Flow(self.latent_dim, 'planar', 16)
105 |
106 | modules = []
107 | if hidden_dims is None:
108 | hidden_dims = [32, 64, 128, 256, 512]
109 |
110 | # Build Encoder
111 | for h_dim in hidden_dims:
112 | modules.append(
113 | nn.Sequential(
114 | nn.Conv2d(in_channels, out_channels=h_dim,
115 | kernel_size=3, stride=2, padding=1),
116 | nn.BatchNorm2d(h_dim),
117 | nn.LeakyReLU())
118 | )
119 | in_channels = h_dim
120 |
121 | self.encoder = nn.Sequential(*modules) # default 5 stride-2 blocks: 64x64 input -> 2x2 map, hence hidden_dims[-1]*4
122 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dims)
123 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dims)
124 |
125 | # Build Decoder
126 | modules = []
127 | self.decoder_input = nn.Linear(latent_dims, hidden_dims[-1] * 4)
128 |
129 | hidden_dims = hidden_dims[::-1] # reversed copy; .reverse() would mutate the caller's (e.g. config) list
130 |
131 | for i in range(len(hidden_dims) - 1):
132 | modules.append(
133 | nn.Sequential(
134 | nn.ConvTranspose2d(hidden_dims[i],
135 | hidden_dims[i + 1],
136 | kernel_size=3,
137 | stride=2,
138 | padding=1,
139 | output_padding=1),
140 | nn.BatchNorm2d(hidden_dims[i + 1]),
141 | nn.LeakyReLU())
142 | )
143 |
144 | self.decoder = nn.Sequential(*modules)
145 |
146 | self.final_layer = nn.Sequential(
147 | nn.ConvTranspose2d(hidden_dims[-1],
148 | hidden_dims[-1],
149 | kernel_size=3,
150 | stride=2,
151 | padding=1,
152 | output_padding=1),
153 | nn.BatchNorm2d(hidden_dims[-1]),
154 | nn.LeakyReLU(),
155 | nn.Conv2d(hidden_dims[-1], out_channels=3,
156 | kernel_size=3, padding=1),
157 | nn.Tanh())
158 |
159 | def encode(self, input: Tensor) -> List[Tensor]:
160 | """
161 | Encodes the input by passing through the convolutional network
162 | and outputs the latent variables.
163 |
164 | Params:
165 | input (Tensor): Input tensor [N x C x H x W]
166 |
167 | Returns:
168 | mu (Tensor) and log_var (Tensor) of latent variables
169 | """
170 |
171 | result = self.encoder(input)
172 | result = torch.flatten(result, start_dim=1)
173 |
174 | # Project the features onto the mean and log-variance
175 | # of the diagonal Gaussian posterior q(z|x)
176 | mu = self.fc_mu(result)
177 | log_var = self.fc_var(result)
178 |
179 | if self.flow_check:
180 | z, log_det = self.reparameterize(mu, log_var)
181 | return mu, log_var, z, log_det
182 |
183 | else:
184 | z = self.reparameterize(mu, log_var)
185 | return mu, log_var, z
186 |
187 | def decode(self, z: Tensor) -> Tensor:
188 | """
189 | Maps the given latent variables
190 | onto the image space.
191 |
192 | Params:
193 | z (Tensor): Latent variable [B x D]
194 |
195 | Returns:
196 | result (Tensor) [B x C x H x W]
197 | """
198 |
199 | result = self.decoder_input(z)
200 | result = result.view(-1, result.size(-1) // 4, 2, 2) # channels = hidden_dims[-1], no hard-coded 512
201 | result = self.decoder(result)
202 | result = self.final_layer(result)
203 |
204 | return result
205 |
206 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
207 | """
208 | Reparameterization trick: draws z ~ N(mu, var) as z = mu + std * eps
209 | with eps ~ N(0, I), so gradients flow through mu and logvar.
210 |
211 | Params:
212 | mu (Tensor): Mean of Gaussian latent variables [B x D]
213 | logvar (Tensor): log-Variance of Gaussian latent variables [B x D]
214 |
215 | Returns:
216 | z (Tensor) [B x D], or the tuple (z, log_det) returned by the flow when flow_check is set
217 | """
218 |
219 | std = torch.exp(0.5 * logvar)
220 | eps = torch.randn_like(std)
221 | z = eps.mul(std).add_(mu)
222 |
223 | if self.flow_check:
224 | return self.flow(z)
225 |
226 | else:
227 | return z
228 |
229 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
230 |
231 | if self.flow_check:
232 | mu, log_var, z, log_det = self.encode(input)
233 |
234 | return self.decode(z), mu, log_var, log_det
235 |
236 | else:
237 | mu, log_var, z = self.encode(input)
238 |
239 | return self.decode(z), mu, log_var
240 |
241 | def sample(self,
242 | num_samples: int,
243 | current_device: int, **kwargs) -> Tensor:
244 | """
245 | Samples z ~ N(0, I) from the latent prior and returns the
246 | corresponding decoded images.
247 |
248 | Params:
249 | num_samples (Int): Number of samples
250 | current_device (int or torch.device): Device on which to sample and decode
251 |
252 | Returns:
253 | samples (Tensor)
254 | """
255 |
256 | z = torch.randn(num_samples,
257 | self.latent_dim)
258 | z = z.to(current_device)
259 | samples = self.decode(z)
260 |
261 | return samples
262 |
263 | def generate(self, x: Tensor, **kwargs) -> Tensor:
264 | """
265 | Given an input image x, returns the reconstructed image
266 |
267 | Params:
268 | x (Tensor): input image Tensor [B x C x H x W]
269 |
270 | Returns:
271 | (Tensor) [B x C x H x W]
272 | """
273 |
274 | return self.forward(x)[0]
275 |
--------------------------------------------------------------------------------
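A quick check on the hard-coded sizes in `VanillaVAE`: `fc_mu`/`fc_var` expect `hidden_dims[-1]*4` features and `decode` reshapes to a 2x2 map, which only matches the encoder when its five stride-2 blocks shrink the input to 2x2. The sketch below uses hypothetical helpers (not part of the repo) that apply the output-size formulas documented for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1), assuming square inputs and the kernel/stride/padding values used in the model.

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size: int, kernel: int = 3, stride: int = 2,
                         padding: int = 1, output_padding: int = 1) -> int:
    """Spatial output size of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_vae_shapes(image_size: int, n_blocks: int = 5):
    """Trace spatial sizes through VanillaVAE's encoder and decoder.

    The encoder has n_blocks stride-2 Conv2d layers. The decoder has
    n_blocks - 1 ConvTranspose2d layers in self.decoder plus one in
    final_layer; the last Conv2d (stride 1, padding 1) keeps the size.
    """
    enc = [image_size]
    for _ in range(n_blocks):
        enc.append(conv2d_out(enc[-1]))
    dec = [2]  # decode() always reshapes to a 2x2 map, whatever the input size
    for _ in range(n_blocks):
        dec.append(conv_transpose2d_out(dec[-1]))
    return enc, dec


if __name__ == "__main__":
    enc, dec = trace_vae_shapes(64)
    print("encoder:", enc)  # encoder: [64, 32, 16, 8, 4, 2]
    print("decoder:", dec)  # decoder: [2, 4, 8, 16, 32, 64]
    # features fed to fc_mu / fc_var when hidden_dims[-1] == 512
    print("fc input:", 512 * enc[-1] ** 2)  # fc input: 2048
```

For a 128x128 input the encoder would end at 4x4 (a flattened size of `512 * 16`, not `512 * 4`), while the decoder would still produce 64x64, so both the linear layers and the decoder depth would need adjusting.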
/model/types_.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable, Union, Any, TypeVar, Tuple
2 | # Use the real tensor class so annotations refer to torch.Tensor
3 | from torch import Tensor
4 |
5 |
--------------------------------------------------------------------------------
/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce, partial
5 | from operator import getitem
6 | from datetime import datetime
7 | from logger import setup_logging
8 | from utils import read_json, write_json
9 |
10 |
11 | class ConfigParser:
12 | def __init__(self, config, resume=None, modification=None, run_id=None):
13 | """
14 | Class to parse the configuration JSON file. Handles training hyperparameters, module initialization, checkpoint saving
15 | and the logging module.
16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
17 | :param resume: String, path to the checkpoint being loaded.
18 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
19 | :param run_id: Unique identifier for the training run, used to name the checkpoint and log directories. Defaults to a timestamp.
20 | """
21 | # load config file and apply modification
22 | self._config = _update_config(config, modification)
23 | self.resume = resume
24 |
25 | # set save_dir where trained model and log will be saved.
26 | save_dir = Path(self.config['trainer']['save_dir'])
27 |
28 | exper_name = self.config['name']
29 | if run_id is None: # use timestamp as default run-id
30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S')
31 | self._save_dir = save_dir / 'models' / exper_name / run_id
32 | self._log_dir = save_dir / 'log' / exper_name / run_id
33 |
34 | # make directory for saving checkpoints and log.
35 | exist_ok = run_id == ''
36 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
37 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
38 |
39 | # save updated config file to the checkpoint dir
40 | write_json(self.config, self.save_dir / 'config.json')
41 |
42 | # configure logging module
43 | setup_logging(self.log_dir)
44 | self.log_levels = {
45 | 0: logging.WARNING,
46 | 1: logging.INFO,
47 | 2: logging.DEBUG
48 | }
49 |
50 | @classmethod
51 | def from_args(cls, args, options=''):
52 | """
53 | Initialize this class from some cli arguments. Used in train, test.
54 | """
55 | for opt in options:
56 | args.add_argument(*opt.flags, default=None, type=opt.type)
57 | if not isinstance(args, tuple):
58 | args = args.parse_args()
59 |
60 | if args.device is not None:
61 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
62 | if args.resume is not None:
63 | resume = Path(args.resume)
64 | cfg_fname = resume.parent / 'config.json'
65 | else:
66 | msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
67 | assert args.config is not None, msg_no_cfg
68 | resume = None
69 | cfg_fname = Path(args.config)
70 |
71 | config = read_json(cfg_fname)
72 | if args.config and resume:
73 | # update new config for fine-tuning
74 | config.update(read_json(args.config))
75 |
76 | # parse custom cli options into dictionary
77 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
78 | return cls(config, resume, modification)
79 |
80 | def init_obj(self, name, module, *args, **kwargs):
81 | """
82 | Finds a function handle with the name given as 'type' in config, and returns the
83 | instance initialized with corresponding arguments given.
84 |
85 | `object = config.init_obj('name', module, a, b=1)`
86 | is equivalent to
87 | `object = module.name(a, b=1)`
88 | """
89 | module_name = self[name]['type']
90 | module_args = dict(self[name]['args'])
91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
92 | module_args.update(kwargs)
93 | return getattr(module, module_name)(*args, **module_args)
94 |
95 | def init_ftn(self, name, module, *args, **kwargs):
96 | """
97 | Finds a function handle with the name given as 'type' in config, and returns the
98 | function with given arguments fixed with functools.partial.
99 |
100 | `function = config.init_ftn('name', module, a, b=1)`
101 | is equivalent to
102 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
103 | """
104 | module_name = self[name]['type']
105 | module_args = dict(self[name]['args'])
106 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
107 | module_args.update(kwargs)
108 | return partial(getattr(module, module_name), *args, **module_args)
109 |
110 | def __getitem__(self, name):
111 | """Access items like ordinary dict."""
112 | return self.config[name]
113 |
114 | def get_logger(self, name, verbosity=2):
115 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
116 | assert verbosity in self.log_levels, msg_verbosity
117 | logger = logging.getLogger(name)
118 | logger.setLevel(self.log_levels[verbosity])
119 | return logger
120 |
121 | # setting read-only attributes
122 | @property
123 | def config(self):
124 | return self._config
125 |
126 | @property
127 | def save_dir(self):
128 | return self._save_dir
129 |
130 | @property
131 | def log_dir(self):
132 | return self._log_dir
133 |
134 | # helper functions to update config dict with custom cli options
135 | def _update_config(config, modification):
136 | if modification is None:
137 | return config
138 |
139 | for k, v in modification.items():
140 | if v is not None:
141 | _set_by_path(config, k, v)
142 | return config
143 |
144 | def _get_opt_name(flags):
145 | for flg in flags:
146 | if flg.startswith('--'):
147 | return flg.replace('--', '')
148 | return flags[0].replace('--', '')
149 |
150 | def _set_by_path(tree, keys, value):
151 | """Set a value in a nested object in tree by sequence of keys."""
152 | keys = keys.split(';')
153 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
154 |
155 | def _get_by_path(tree, keys):
156 | """Access a nested object in tree by sequence of keys."""
157 | return reduce(getitem, keys, tree)
158 |
--------------------------------------------------------------------------------
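`init_obj` looks up the class named by a config section's `type` key inside a module and calls it with that section's `args`, while `_update_config` applies `;`-separated keychains such as `optimizer;args;lr`. A minimal stdlib-only sketch of both mechanisms, using a hypothetical stand-in module (`fake_optim`) and a hypothetical `SGD` class in place of `torch.optim`:

```python
from functools import reduce
from operator import getitem
from types import SimpleNamespace


def set_by_path(tree: dict, keychain: str, value) -> None:
    """Mirror of _set_by_path: 'a;b;c' sets tree['a']['b']['c'] = value."""
    keys = keychain.split(';')
    reduce(getitem, keys[:-1], tree)[keys[-1]] = value


def init_obj(config: dict, name: str, module, *args, **kwargs):
    """Mirror of ConfigParser.init_obj: module.<type>(*args, **config args, **kwargs)."""
    module_args = dict(config[name]['args'])
    assert all(k not in module_args for k in kwargs), \
        'Overwriting kwargs given in config file is not allowed'
    module_args.update(kwargs)
    return getattr(module, config[name]['type'])(*args, **module_args)


class SGD:  # hypothetical stand-in for torch.optim.SGD
    def __init__(self, params, lr, momentum=0.0):
        self.params, self.lr, self.momentum = params, lr, momentum


fake_optim = SimpleNamespace(SGD=SGD)  # hypothetical stand-in for the torch.optim module

config = {"optimizer": {"type": "SGD", "args": {"lr": 1e-3, "momentum": 0.9}}}
set_by_path(config, "optimizer;args;lr", 5e-4)  # what `--lr 5e-4` would write
opt = init_obj(config, "optimizer", fake_optim, ["w"])
print(opt.lr, opt.momentum)  # 0.0005 0.9
```

Positional arguments such as the trainable parameters are passed through `*args`, so only hyperparameters need to live in the JSON file.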
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9
2 | torchvision>=0.10
3 | numpy
4 | tqdm
5 | tensorboard>=2.8
6 | matplotlib
7 | pandas
8 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | from tqdm import tqdm
5 | import data_loader.data_loaders as module_data
6 | import model.loss as module_loss
7 | import model.metric as module_metric
8 | import model.model as module_arch
9 | from parse_config import ConfigParser
10 | from torch.nn import functional as F
11 | import torchvision.utils as vutils
12 | from torchvision import transforms
13 | from torch.autograd import Variable
14 | import os
15 | import matplotlib.pyplot as plt
16 | from mpl_toolkits.axes_grid1 import ImageGrid
17 |
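18 | # Fixes the PosixPath error when loading a checkpoint saved on Linux/macOS from Windows.
19 | import pathlib
20 |
21 | if os.name == 'nt': # only valid on Windows; WindowsPath cannot be instantiated on POSIX systems
22 | pathlib.PosixPath = pathlib.WindowsPath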
23 |
24 | def latent_traversal(model, samples, n_changes=5, val_range=(-1, 1)):
25 | """Performs latent traversals in a VAE's latent space.
26 | model: VanillaVAE
27 | The trained VAE used to encode the samples and decode the traversals
28 | samples: Tensor
29 | Batch of input images [B x C x H x W] on the model's device;
30 | one traversal figure is produced per image
31 | (plot_traversal writes each figure to 'traversal.png', overwriting the previous one)
32 | n_changes: int
33 | The number of changes to perform on one latent dimension
34 | val_range: tuple
35 | The range of values that can be set for one latent dimension
36 | """
37 | # Encode the samples; encode() returns z at index 2 with or without flow_check
38 | # (index -1 would be log_det when flow_check is enabled)
39 | z_base = model.encode(samples)[2]
40 | z_base = z_base.detach().cpu()
41 | # Images are plotted channels-last; the decoder's Tanh output is mapped to [0, 1] below
42 | r, c = n_changes, z_base.shape[1]
43 | vals = np.linspace(*val_range, r)
44 | shape = samples[0].permute(1, 2, 0).shape
45 | for j, z in enumerate(z_base):
46 | imgs = np.empty([r * c, *shape])
47 | for i in range(c):
48 | z_iter = np.tile(z.numpy(), [r, 1])
49 | z_iter[:, i] = vals
50 | z_iter = torch.from_numpy(z_iter)
51 | z_iter = z_iter.to(samples.device)
52 | imgs[r * i:(r * i) + r] = ((model.decode(z_iter) + 1) / 2).permute(0, 2, 3, 1).detach().cpu().numpy()
53 | plot_traversal(imgs, r, c, shape[-1] == 1, show=True)
54 | # save_figure(fname, tight=False)
55 |
56 |
57 | def plot_traversal(imgs, r, c, greyscale, show=False):
58 | fig = plt.figure(figsize=(20., 20.))
59 | grid = ImageGrid(fig, 111, nrows_ncols=(r, c), axes_pad=0, direction="column")
60 |
61 | for i, (ax, im) in enumerate(zip(grid, imgs)):
62 | ax.set_axis_off()
63 | if i % r == 0:
64 | ax.set_title("z{}".format(i // r), fontdict={'fontsize': 25})
65 | if greyscale is True:
66 | ax.imshow(im.squeeze(-1), cmap="gray")
67 | else:
68 | ax.imshow(im)
69 |
70 | fig.subplots_adjust(wspace=0, hspace=0)
71 | fig.savefig('traversal.png') # save before show(), which may close the figure
72 | if show is True:
73 | plt.show()
74 |
75 |
76 | def interpolate(autoencoder, x_1, x_2, n=12):
77 | z_1 = autoencoder.encode(x_1)[2]
78 | z_2 = autoencoder.encode(x_2)[2]
79 | z = torch.stack([z_1 + (z_2 - z_1)*t for t in np.linspace(0, 1, n)])
80 | interpolate_list = autoencoder.decode(z)
81 | interpolate_list = interpolate_list.to('cpu').detach()
82 | print(len(interpolate_list))
83 |
84 | plt.figure(figsize=(64, 64))
85 | for i in range(len(interpolate_list)):
86 | ax = plt.subplot(1, len(interpolate_list), i+1)
87 | plt.imshow(((interpolate_list[i] + 1) / 2).permute(1, 2, 0).numpy()) # Tanh output [-1, 1] -> [0, 1]
88 | ax.get_xaxis().set_visible(False)
89 | ax.get_yaxis().set_visible(False)
90 | plt.savefig('linear_interpolation.png')
91 |
92 |
93 | def main(config):
94 | logger = config.get_logger('test')
95 |
96 | # setup data_loader instances
97 | data_loader = getattr(module_data, config['data_loader']['type'])(
98 | config['data_loader']['args']['data_dir'],
99 | batch_size=36,
100 | shuffle=False,
101 | validation_split=0.0,
102 | # training=False,
103 | num_workers=2
104 | )
105 |
106 | # build model architecture
107 | model = config.init_obj('arch', module_arch)
108 | logger.info(model)
109 |
110 | # get function handles of loss and metrics
111 | loss_fn = getattr(module_loss, config['loss'])
112 | # metric_fns = [getattr(module_metric, met) for met in config['metrics']]
113 |
114 | logger.info('Loading checkpoint: {} ...'.format(config.resume))
115 | # checkpoint = torch.load(config.resume)
116 |
117 | # loading on CPU-only machine
118 | checkpoint = torch.load(config.resume, map_location=torch.device('cpu'))
119 | state_dict = checkpoint['state_dict']
120 | if config['n_gpu'] > 1:
121 | model = torch.nn.DataParallel(model)
122 | model.load_state_dict(state_dict)
123 |
124 | # prepare model for testing
125 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
126 | model = model.to(device)
127 | model.eval()
128 |
129 | total_loss = 0.0
130 | # total_metrics = torch.zeros(len(metric_fns))
131 |
132 | with torch.no_grad():
133 | for i, (data, target) in enumerate(tqdm(data_loader)):
134 | data, target = data.to(device), target.to(device)
135 | output, mu, logvar = model(data)
136 |
137 | # computing loss, metrics on test set
138 | loss = loss_fn(output, data, mu, logvar)
139 | batch_size = data.shape[0]
140 | total_loss += loss.item() * batch_size
141 | # for i, metric in enumerate(metric_fns):
142 | # total_metrics[i] += metric(output, target) * batch_size
143 |
144 | # Reconstructing and generating images for a mini-batch
145 | test_input, test_label = next(iter(data_loader))
146 | test_input = test_input.to(device)
147 | test_label = test_label.to(device)
148 |
149 | recons = model.generate(test_input, labels=test_label)
150 | vutils.save_image(recons.data,
151 | os.path.join(
152 | "Reconstructions",
153 | f"recons_{logger.name}_epoch_{config['trainer']['epochs']}.png"),
154 | normalize=True,
155 | nrow=6)
156 |
157 | try:
158 | samples = model.sample(36,
159 | device,
160 | labels=test_label)
161 | vutils.save_image(samples.cpu().data,
162 | os.path.join(
163 | "Samples",
164 | f"{logger.name}.png"),
165 | normalize=True,
166 | nrow=6)
167 | except Warning:
168 | pass
169 |
170 | # linear interpolation between two chosen images
171 | x_1 = test_input[1].to(device)
172 | x_1 = torch.unsqueeze(x_1, dim=0)
173 | x_2 = test_input[2].to(device)
174 | x_2 = torch.unsqueeze(x_2, dim=0)
175 | interpolate(model, x_1, x_2, n=5)
176 |
177 | n_samples = len(data_loader.sampler)
178 | log = {'loss': total_loss / n_samples}
179 | # log.update({
180 | # met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
181 | # })
182 | logger.info(log)
183 |
184 |
185 | if __name__ == '__main__':
186 | args = argparse.ArgumentParser(description='PyTorch Template')
187 | args.add_argument('-c', '--config', default=None, type=str,
188 | help='config file path (default: None)')
189 | args.add_argument('-r', '--resume', default=None, type=str,
190 | help='path to latest checkpoint (default: None)')
191 | args.add_argument('-d', '--device', default=None, type=str,
192 | help='indices of GPUs to enable (default: all)')
193 |
194 | config = ConfigParser.from_args(args)
195 | main(config)
196 |
--------------------------------------------------------------------------------
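`interpolate` decodes points on the straight line between two latent codes, and `latent_traversal` sweeps one latent dimension at a time over `np.linspace(*val_range, n_changes)` while holding the others fixed. A dependency-free sketch of both latent grids on plain Python lists follows; the helper names are hypothetical, and in the script the resulting rows would be stacked into a tensor and passed to `model.decode`.

```python
from typing import List, Sequence, Tuple


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced values including both endpoints (like np.linspace)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def interpolation_grid(z1: Sequence[float], z2: Sequence[float], n: int) -> List[List[float]]:
    """Rows z1 + (z2 - z1) * t for t evenly spaced in [0, 1]."""
    return [[a + (b - a) * t for a, b in zip(z1, z2)] for t in linspace(0.0, 1.0, n)]


def traversal_grid(z: Sequence[float], n_changes: int,
                   val_range: Tuple[float, float] = (-1.0, 1.0)) -> List[List[float]]:
    """For each latent dimension i, n_changes copies of z with z[i] swept over val_range.

    Rows are grouped dimension by dimension, matching imgs[r * i:(r * i) + r]
    in latent_traversal.
    """
    vals = linspace(val_range[0], val_range[1], n_changes)
    rows = []
    for i in range(len(z)):
        for v in vals:
            row = list(z)
            row[i] = v
            rows.append(row)
    return rows


if __name__ == "__main__":
    print(interpolation_grid([0.0, 2.0], [1.0, 0.0], 3))
    # [[0.0, 2.0], [0.5, 1.0], [1.0, 0.0]]
    print(len(traversal_grid([0.1, 0.2, 0.3], n_changes=5)))  # 15 (3 dims x 5 values)
```

Because traversal values are absolute rather than offsets from `z`, a `val_range` of (-1, 1) covers roughly one standard deviation of the N(0, I) prior on each axis.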
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import torch
4 | import numpy as np
5 | import data_loader.data_loaders as module_data
6 | import model.loss as module_loss
7 | import model.metric as module_metric
8 | import model.model as module_arch
9 | from parse_config import ConfigParser
10 | from trainer import Trainer
11 | from utils import prepare_device
12 |
13 |
14 | # fix random seeds for reproducibility
15 | SEED = 123
16 | torch.manual_seed(SEED)
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 | np.random.seed(SEED)
20 |
21 |
22 | def main(config):
23 | logger = config.get_logger('train')
24 |
25 | # setup data_loader instances
26 | data_loader = config.init_obj('data_loader', module_data)
27 | valid_data_loader = data_loader.split_validation()
28 |
29 | # build model architecture, then print to console
30 | model = config.init_obj('arch', module_arch)
31 | logger.info(model)
32 |
33 | # prepare for (multi-device) GPU training
34 | device, device_ids = prepare_device(config['n_gpu'])
35 | model = model.to(device)
36 | if len(device_ids) > 1:
37 | model = torch.nn.DataParallel(model, device_ids=device_ids)
38 |
39 | # get function handles of loss and metrics
40 | criterion = getattr(module_loss, config['loss'])
41 | # metrics = [getattr(module_metric, met) for met in config['metrics']]
42 |
43 | # build optimizer and learning rate scheduler. Delete every line containing lr_scheduler to disable the scheduler
44 | trainable_params = filter(lambda p: p.requires_grad, model.parameters())
45 | optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
46 | lr_scheduler = config.init_obj(
47 | 'lr_scheduler', torch.optim.lr_scheduler, optimizer)
48 |
49 | trainer = Trainer(model, criterion, optimizer,
50 | config=config,
51 | device=device,
52 | data_loader=data_loader,
53 | valid_data_loader=valid_data_loader,
54 | lr_scheduler=lr_scheduler)
55 |
56 | trainer.train()
57 |
58 |
59 | if __name__ == '__main__':
60 | args = argparse.ArgumentParser(description='PyTorch Template')
61 | args.add_argument('-c', '--config', default=None, type=str,
62 | help='config file path (default: None)')
63 | args.add_argument('-r', '--resume', default=None, type=str,
64 | help='path to latest checkpoint (default: None)')
65 | args.add_argument('-d', '--device', default=None, type=str,
66 | help='indices of GPUs to enable (default: all)')
67 |
68 | # custom cli options to modify configuration from default values given in json file.
69 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
70 | options = [
71 | CustomArgs(['--lr', '--learning_rate'],
72 | type=float, target='optimizer;args;lr'),
73 | CustomArgs(['--bs', '--batch_size'], type=int,
74 | target='data_loader;args;batch_size')
75 | ]
76 | config = ConfigParser.from_args(args, options)
77 | main(config)
78 |
--------------------------------------------------------------------------------
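In `train.py`, each `CustomArgs` entry becomes a CLI flag, and `ConfigParser.from_args` turns the parsed values into a `{target: value}` dict; `_update_config` then skips every `None` (flag not given). A stdlib-only sketch of that mapping, parsing an explicit, illustrative argument list instead of `sys.argv`:

```python
import argparse
import collections

CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
    CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size'),
]


def get_opt_name(flags):
    """Mirror of _get_opt_name: the first '--long' flag names the argparse attribute."""
    for flg in flags:
        if flg.startswith('--'):
            return flg.replace('--', '')
    return flags[0].replace('--', '')


parser = argparse.ArgumentParser()
for opt in options:
    parser.add_argument(*opt.flags, default=None, type=opt.type)
args = parser.parse_args(['--bs', '64'])  # illustrative command line

modification = {opt.target: getattr(args, get_opt_name(opt.flags)) for opt in options}
# _update_config only writes values that are not None
applied = {k: v for k, v in modification.items() if v is not None}
print(applied)  # {'data_loader;args;batch_size': 64}
```

One caveat of this scheme: argparse converts `-` to `_` when deriving the attribute name, so a flag such as `--batch-size` would make the `getattr` lookup fail; the existing flags use underscores only.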
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import *
2 |
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.utils import make_grid
4 | from base import BaseTrainer
5 | from utils import inf_loop, MetricTracker
6 |
7 |
8 | class Trainer(BaseTrainer):
9 | """
10 | Trainer for the VAE models; each batch is scored with criterion(output, data, mu, logvar)
11 | """
12 |
13 | def __init__(self, model, criterion, optimizer, config, device,
14 | data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
15 | super().__init__(model, criterion, optimizer, config)
16 | self.config = config
17 | self.device = device
18 | self.data_loader = data_loader
19 | if len_epoch is None:
20 | # epoch-based training
21 | self.len_epoch = len(self.data_loader)
22 | else:
23 | # iteration-based training
24 | self.data_loader = inf_loop(data_loader)
25 | self.len_epoch = len_epoch
26 | self.valid_data_loader = valid_data_loader
27 | self.do_validation = self.valid_data_loader is not None
28 | self.lr_scheduler = lr_scheduler
29 | self.log_step = int(np.sqrt(data_loader.batch_size))
30 |
31 | # self.train_metrics = MetricTracker(
32 | # 'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
33 | # self.valid_metrics = MetricTracker(
34 | # 'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
35 |
36 | self.train_metrics = MetricTracker(
37 | 'loss', writer=self.writer)
38 | self.valid_metrics = MetricTracker(
39 | 'loss', writer=self.writer)
40 |
41 | def _train_epoch(self, epoch):
42 | """
43 | Training logic for an epoch
44 |
45 | :param epoch: Integer, current training epoch.
46 | :return: A log that contains average loss and metric in this epoch.
47 | """
48 | self.model.train()
49 | self.train_metrics.reset()
50 | for batch_idx, (data, target) in enumerate(self.data_loader):
51 | data, target = data.to(self.device), target.to(self.device)
52 |
53 | self.optimizer.zero_grad()
54 | output, mu, logvar = self.model(data)
55 | loss = self.criterion(output, data, mu, logvar)
56 | loss.backward()
57 |
58 | # optional gradient clipping
59 | # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=2.0, norm_type=2)
60 |
61 | self.optimizer.step()
62 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
63 | self.train_metrics.update('loss', loss.item())
64 | # for met in self.metric_ftns:
65 | # self.train_metrics.update(met.__name__, met(output, target))
66 |
67 | if batch_idx % self.log_step == 0:
68 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
69 | epoch,
70 | self._progress(batch_idx),
71 | loss.item()))
72 | self.writer.add_image('input', make_grid(
73 | data.cpu(), nrow=8, normalize=True))
74 |
75 | if batch_idx + 1 == self.len_epoch: # stop after exactly len_epoch batches
76 | break
77 | log = self.train_metrics.result()
78 |
79 | if self.do_validation:
80 | val_log = self._valid_epoch(epoch)
81 | log.update(**{'val_'+k: v for k, v in val_log.items()})
82 |
83 | if self.lr_scheduler is not None:
84 | self.lr_scheduler.step()
85 | return log
86 |
87 | def _valid_epoch(self, epoch):
88 | """
89 | Validate after training an epoch
90 |
91 | :param epoch: Integer, current training epoch.
92 | :return: A log that contains information about validation
93 | """
94 | self.model.eval()
95 | self.valid_metrics.reset()
96 | with torch.no_grad():
97 | for batch_idx, (data, target) in enumerate(self.valid_data_loader):
98 | data, target = data.to(self.device), target.to(self.device)
99 | output, mu, logvar = self.model(data)
100 | loss = self.criterion(output, data, mu, logvar)
101 |
102 | self.writer.set_step(
103 | (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
104 | self.valid_metrics.update('loss', loss.item())
105 | # for met in self.metric_ftns:
106 | # self.valid_metrics.update(
107 | # met.__name__, met(output, target))
108 | self.writer.add_image('input', make_grid(
109 | data.cpu(), nrow=8, normalize=True))
110 |
111 | # add histogram of model parameters to the tensorboard
112 | for name, p in self.model.named_parameters():
113 | self.writer.add_histogram(name, p, bins='auto')
114 | return self.valid_metrics.result()
115 |
116 | def _progress(self, batch_idx):
117 | base = '[{}/{} ({:.0f}%)]'
118 | if hasattr(self.data_loader, 'n_samples'):
119 | current = batch_idx * self.data_loader.batch_size
120 | total = self.data_loader.n_samples
121 | else:
122 | current = batch_idx
123 | total = self.len_epoch
124 | return base.format(current, total, 100.0 * current / total)
125 |
--------------------------------------------------------------------------------
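The trainer derives three bookkeeping values from the loader: a global TensorBoard step `(epoch - 1) * len_epoch + batch_idx`, the progress string built by `_progress`, and a logging interval of `int(sqrt(batch_size))` batches. A small stdlib sketch reproducing that arithmetic outside the class (the function names are hypothetical):

```python
import math


def global_step(epoch: int, batch_idx: int, len_epoch: int) -> int:
    """TensorBoard step as in Trainer: epochs are 1-based, batch indices 0-based."""
    return (epoch - 1) * len_epoch + batch_idx


def progress(batch_idx: int, batch_size: int, n_samples: int) -> str:
    """Mirror of Trainer._progress when the loader exposes n_samples."""
    current = batch_idx * batch_size
    return '[{}/{} ({:.0f}%)]'.format(current, n_samples, 100.0 * current / n_samples)


def log_step(batch_size: int) -> int:
    """Trainer logs every int(sqrt(batch_size)) batches."""
    return int(math.sqrt(batch_size))


if __name__ == "__main__":
    print(global_step(epoch=3, batch_idx=10, len_epoch=100))  # 210
    print(progress(batch_idx=10, batch_size=32, n_samples=1000))  # [320/1000 (32%)]
    print(log_step(64))  # 8
```

In iteration-based mode `self.data_loader` is the `inf_loop` generator, which has no `n_samples` attribute, so `_progress` falls back to counting batches against `len_epoch`.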
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import pandas as pd
4 | from pathlib import Path
5 | from itertools import repeat
6 | from collections import OrderedDict
7 |
8 |
9 | def ensure_dir(dirname):
10 | dirname = Path(dirname)
11 | if not dirname.is_dir():
12 | dirname.mkdir(parents=True, exist_ok=False)
13 |
14 | def read_json(fname):
15 | fname = Path(fname)
16 | with fname.open('rt') as handle:
17 | return json.load(handle, object_hook=OrderedDict)
18 |
19 | def write_json(content, fname):
20 | fname = Path(fname)
21 | with fname.open('wt') as handle:
22 | json.dump(content, handle, indent=4, sort_keys=False)
23 |
24 | def inf_loop(data_loader):
25 | ''' wrapper function for endless data loader. '''
26 | for loader in repeat(data_loader):
27 | yield from loader
28 |
29 | def prepare_device(n_gpu_use):
30 | """
31 | setup GPU device if available. get gpu device indices which are used for DataParallel
32 | """
33 | n_gpu = torch.cuda.device_count()
34 | if n_gpu_use > 0 and n_gpu == 0:
35 | print("Warning: There's no GPU available on this machine, "
36 | "training will be performed on CPU.")
37 | n_gpu_use = 0
38 | if n_gpu_use > n_gpu:
39 | print(f"Warning: The number of GPUs configured to use is {n_gpu_use}, but only {n_gpu} are "
40 | "available on this machine.")
41 | n_gpu_use = n_gpu
42 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
43 | list_ids = list(range(n_gpu_use))
44 | return device, list_ids
45 |
46 | class MetricTracker:
47 | def __init__(self, *keys, writer=None):
48 | self.writer = writer
49 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
50 | self.reset()
51 |
52 | def reset(self):
53 | for col in self._data.columns:
54 | self._data[col] = 0.0 # assign whole columns; writing through .values may hit a read-only array under copy-on-write
55 |
56 | def update(self, key, value, n=1):
57 | if self.writer is not None:
58 | self.writer.add_scalar(key, value)
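59 | self._data.loc[key, 'total'] += value * n # .loc writes into the frame; chained df.total[key] += ... may be a no-op under copy-on-write
60 | self._data.loc[key, 'counts'] += n
61 | self._data.loc[key, 'average'] = self._data.loc[key, 'total'] / self._data.loc[key, 'counts']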
62 |
63 | def avg(self, key):
64 | return self._data.average[key]
65 |
66 | def result(self):
67 | return dict(self._data.average)
68 |
--------------------------------------------------------------------------------
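`MetricTracker` keeps a running, sample-weighted average per key (`total += value * n`, `counts += n`), and `inf_loop` restarts a loader forever for iteration-based training. A minimal, pandas-free sketch of the same bookkeeping; the class name `RunningAverage` is hypothetical, and unlike the repo's version it does not forward updates to a TensorBoard writer.

```python
from itertools import islice, repeat


class RunningAverage:
    """Pandas-free sketch of MetricTracker's total / counts / average bookkeeping."""

    def __init__(self, *keys):
        self._keys = keys
        self.reset()

    def reset(self):
        # start every key from zero, like MetricTracker.reset()
        self._total = {k: 0.0 for k in self._keys}
        self._counts = {k: 0 for k in self._keys}

    def update(self, key, value, n=1):
        # value is treated as an average over n samples, so it is weighted by n
        self._total[key] += value * n
        self._counts[key] += n

    def avg(self, key):
        # a key that has not been updated since reset() reports 0.0
        count = self._counts[key]
        return self._total[key] / count if count else 0.0

    def result(self):
        return {k: self.avg(k) for k in self._keys}


def inf_loop(data_loader):
    """Same logic as utils.inf_loop: restart the iterable whenever it is exhausted."""
    for loader in repeat(data_loader):
        yield from loader


if __name__ == "__main__":
    tracker = RunningAverage('loss')
    tracker.update('loss', 2.0, n=2)
    tracker.update('loss', 5.0)
    print(tracker.result())  # {'loss': 3.0}
    print(list(islice(inf_loop([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]
```

Plain dictionaries avoid the pandas indexing pitfalls entirely, at the cost of losing the tabular `_data` view the original exposes.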