├── .gitignore ├── LICENSE ├── README.md ├── arts ├── mnist-svhn-exact-replay.png ├── mnist-svhn-generative-replay-sample.jpg ├── mnist-svhn-generative-replay.gif ├── mnist-svhn-generative-replay.png ├── mnist-svhn-none-sample.jpg ├── mnist-svhn-none.gif ├── mnist-svhn-none.png ├── model.png ├── permutated-mnist-exact-replay.png ├── permutated-mnist-generative-replay.png ├── permutated-mnist-none.png ├── svhn-mnist-exact-replay.png ├── svhn-mnist-generative-replay-r0.4-sample.jpg ├── svhn-mnist-generative-replay.gif ├── svhn-mnist-generative-replay.png ├── svhn-mnist-none-r1-sample.jpg ├── svhn-mnist-none.gif └── svhn-mnist-none.png ├── const.py ├── data.py ├── dgr.py ├── gan.py ├── main.py ├── models.py ├── requirements.txt ├── run_full_experiments ├── train.py ├── utils.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | 104 | datasets/ 105 | checkpoints/ 106 | samples/ 107 | .env 108 | _cache_* 109 | .download.cgi* 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ha Junsoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-deep-generative-replay 2 | 3 | PyTorch implementation of [Continual Learning with Deep Generative Replay, NIPS 2017](https://arxiv.org/abs/1705.08690) 4 | 5 | ![model](./arts/model.png) 6 | 7 | 8 | ## Results 9 | 10 | ### Continual Learning on Permutated MNISTs 11 | 12 | - Test precisions **without replay** (*left*), **with exact replay** (*middle*), and **with Deep Generative Replay** (*right*). 13 | 14 | 15 | 16 | ### Continual Learning on MNIST-SVHN 17 | 18 | - Test precisions **without replay** (*left*), **with exact replay** (*middle*), and **with Deep Generative Replay** (*right*). 19 | 20 | 21 | 22 | - Generated samples from the scholar trained **without any replay** (*left*) and **with Deep Generative Replay** (*right*). 23 | 24 | 25 | 26 | - Training scholar's generator **without replay** (*left*) and **with Deep Generative Replay** (*right*). 27 | 28 | 29 | 30 | ### Continual Learning on SVHN-MNIST 31 | 32 | - Test precisions **without replay** (*left*), **with exact replay** (*middle*), and **with Deep Generative Replay** (*right*). 33 | 34 | 35 | 36 | - Generated samples from the scholar trained **without replay** (*left*) and **with Deep Generative Replay** (*right*). 37 | 38 | 39 | 40 | - Training scholar's generator **without replay** (*left*) and **with Deep Generative Replay** (*right*). 
41 | 42 | 43 | 44 | 45 | ## Installation 46 | ```shell 47 | $ git clone https://github.com/kuc2477/pytorch-deep-generative-replay 48 | $ pip install -r pytorch-deep-generative-replay/requirements.txt 49 | ``` 50 | 51 | ## Commands 52 | 53 | ### Usage 54 | ```shell 55 | $ ./main.py --help 56 | $ usage: PyTorch implementation of Deep Generative Replay [-h] 57 | [--experiment {permutated-mnist,svhn-mnist,mnist-svhn}] 58 | [--mnist-permutation-number MNIST_PERMUTATION_NUMBER] 59 | [--mnist-permutation-seed MNIST_PERMUTATION_SEED] 60 | --replay-mode 61 | {exact-replay,generative-replay,none} 62 | [--generator-z-size GENERATOR_Z_SIZE] 63 | [--generator-c-channel-size GENERATOR_C_CHANNEL_SIZE] 64 | [--generator-g-channel-size GENERATOR_G_CHANNEL_SIZE] 65 | [--solver-depth SOLVER_DEPTH] 66 | [--solver-reducing-layers SOLVER_REDUCING_LAYERS] 67 | [--solver-channel-size SOLVER_CHANNEL_SIZE] 68 | [--generator-c-updates-per-g-update GENERATOR_C_UPDATES_PER_G_UPDATE] 69 | [--generator-iterations GENERATOR_ITERATIONS] 70 | [--solver-iterations SOLVER_ITERATIONS] 71 | [--importance-of-new-task IMPORTANCE_OF_NEW_TASK] 72 | [--lr LR] 73 | [--weight-decay WEIGHT_DECAY] 74 | [--batch-size BATCH_SIZE] 75 | [--test-size TEST_SIZE] 76 | [--sample-size SAMPLE_SIZE] 77 | [--image-log-interval IMAGE_LOG_INTERVAL] 78 | [--eval-log-interval EVAL_LOG_INTERVAL] 79 | [--loss-log-interval LOSS_LOG_INTERVAL] 80 | [--checkpoint-dir CHECKPOINT_DIR] 81 | [--sample-dir SAMPLE_DIR] 82 | [--no-gpus] 83 | (--train | --test) 84 | 85 | ``` 86 | 87 | ### To Run Full Experiments 88 | ```shell 89 | # Run a visdom server and conduct full experiments 90 | $ python -m visdom.server & 91 | $ ./run_full_experiments 92 | ``` 93 | 94 | ### To Run a Single Experiment 95 | ```shell 96 | # Run a visdom server and conduct a desired experiment 97 | $ python -m visdom.server & 98 | $ ./main.py --train --experiment=[permutated-mnist|svhn-mnist|mnist-svhn] --replay-mode=[exact-replay|generative-replay|none] 99 | ``` 100 | 101 | ### To Generate Images from the learned Scholar 102 | ```shell 103 | $ # Run the command below and visit the "samples" directory 104 | $ ./main.py --test --experiment=[permutated-mnist|svhn-mnist|mnist-svhn] --replay-mode=[exact-replay|generative-replay|none] 105 | ``` 106 | 107 | ## Note 108 | - I failed to find the supplementary materials that the authors mentioned in the paper to contain the experimental details. Thus, I arbitrarily chose a 4-convolutional-layer CNN as a solver in this project. If you know where I can find the additional materials, please let me know through the project's Github issue. 
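- For reference, the solver that the default CLI options build is sketched below (a rough illustration of the defaults in `main.py` applied to the `CNN` class in `models.py`; these values are this project's arbitrary choices, as noted above, not details from the paper):

```python
# Sketch: the default solver as constructed in main.py, with the CLI defaults written out.
from models import CNN

solver = CNN(
    image_size=32,         # DATASET_CONFIGS[...]['size']
    image_channel_size=3,  # 1 for permutated MNIST, 3 for SVHN / colorized MNIST
    classes=10,
    depth=5,               # --solver-depth: yields 4 convolutional layers plus a linear head
    channel_size=1024,     # --solver-channel-size: width of the last convolutional layer
    reducing_layers=3,     # --solver-reducing-layers: number of stride-2 convolutions
)
```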
109 | 110 | ## Reference 111 | - [Continual Learning with Deep Generative Replay, arxiv:1705.08690](https://arxiv.org/abs/1705.08690) 112 | 113 | 114 | ## Author 115 | Ha Junsoo / [@kuc2477](https://github.com/kuc2477) / MIT License 116 | -------------------------------------------------------------------------------- /arts/mnist-svhn-exact-replay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-exact-replay.png -------------------------------------------------------------------------------- /arts/mnist-svhn-generative-replay-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-generative-replay-sample.jpg -------------------------------------------------------------------------------- /arts/mnist-svhn-generative-replay.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-generative-replay.gif -------------------------------------------------------------------------------- /arts/mnist-svhn-generative-replay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-generative-replay.png -------------------------------------------------------------------------------- /arts/mnist-svhn-none-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-none-sample.jpg -------------------------------------------------------------------------------- /arts/mnist-svhn-none.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-none.gif -------------------------------------------------------------------------------- /arts/mnist-svhn-none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-none.png -------------------------------------------------------------------------------- /arts/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/model.png -------------------------------------------------------------------------------- /arts/permutated-mnist-exact-replay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/permutated-mnist-exact-replay.png -------------------------------------------------------------------------------- /arts/permutated-mnist-generative-replay.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/permutated-mnist-generative-replay.png -------------------------------------------------------------------------------- /arts/permutated-mnist-none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/permutated-mnist-none.png -------------------------------------------------------------------------------- /arts/svhn-mnist-exact-replay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-exact-replay.png -------------------------------------------------------------------------------- /arts/svhn-mnist-generative-replay-r0.4-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-generative-replay-r0.4-sample.jpg -------------------------------------------------------------------------------- /arts/svhn-mnist-generative-replay.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-generative-replay.gif -------------------------------------------------------------------------------- /arts/svhn-mnist-generative-replay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-generative-replay.png -------------------------------------------------------------------------------- /arts/svhn-mnist-none-r1-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-none-r1-sample.jpg -------------------------------------------------------------------------------- /arts/svhn-mnist-none.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-none.gif -------------------------------------------------------------------------------- /arts/svhn-mnist-none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-none.png -------------------------------------------------------------------------------- /const.py: -------------------------------------------------------------------------------- 1 | EPSILON = 1e-16 2 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from torchvision import datasets, transforms 4 | from torchvision.transforms import ImageOps 5 | from torch.utils.data import ConcatDataset 6 | 7 | 8 | def _permutate_image_pixels(image, permutation): 9 | if 
permutation is None: 10 | return image 11 | 12 | c, h, w = image.size() 13 | image = image.view(-1, c) 14 | image = image[permutation, :] 15 | return image.view(c, h, w) 16 | 17 | 18 | def _colorize_grayscale_image(image): 19 | return ImageOps.colorize(image, (0, 0, 0), (255, 255, 255)) 20 | 21 | 22 | def get_dataset(name, train=True, permutation=None, capacity=None): 23 | dataset = (TRAIN_DATASETS[name] if train else TEST_DATASETS[name])() 24 | dataset.transform = transforms.Compose([ 25 | dataset.transform, 26 | transforms.Lambda(lambda x: _permutate_image_pixels(x, permutation)), 27 | ]) 28 | 29 | if capacity is not None and len(dataset) < capacity: 30 | return ConcatDataset([ 31 | copy.deepcopy(dataset) for _ in 32 | range(math.ceil(capacity / len(dataset))) 33 | ]) 34 | else: 35 | return dataset 36 | 37 | 38 | _MNIST_TRAIN_TRANSFORMS = _MNIST_TEST_TRANSFORMS = [ 39 | transforms.ToTensor(), 40 | transforms.ToPILImage(), 41 | transforms.Pad(2), 42 | transforms.ToTensor(), 43 | ] 44 | 45 | _MNIST_COLORIZED_TRAIN_TRANSFORMS = _MNIST_COLORIZED_TEST_TRANSFORMS = [ 46 | transforms.ToTensor(), 47 | transforms.ToPILImage(), 48 | transforms.Lambda(lambda x: _colorize_grayscale_image(x)), 49 | transforms.Pad(2), 50 | transforms.ToTensor(), 51 | ] 52 | 53 | _CIFAR_TRAIN_TRANSFORMS = _CIFAR_TEST_TRANSFORMS = [ 54 | transforms.ToTensor(), 55 | ] 56 | 57 | _SVHN_TRAIN_TRANSFORMS = _SVHN_TEST_TRANSFORMS = [ 58 | transforms.ToTensor(), 59 | ] 60 | _SVHN_TARGET_TRANSFORMS = [ 61 | transforms.Lambda(lambda y: y % 10) 62 | ] 63 | 64 | 65 | TRAIN_DATASETS = { 66 | 'mnist': lambda: datasets.MNIST( 67 | './datasets/mnist', train=True, download=True, 68 | transform=transforms.Compose(_MNIST_TRAIN_TRANSFORMS) 69 | ), 70 | 'mnist-color': lambda: datasets.MNIST( 71 | './datasets/mnist', train=True, download=True, 72 | transform=transforms.Compose(_MNIST_COLORIZED_TRAIN_TRANSFORMS) 73 | ), 74 | 'cifar10': lambda: datasets.CIFAR10( 75 | './datasets/cifar10', train=True, download=True, 76 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS) 77 | ), 78 | 'cifar100': lambda: datasets.CIFAR100( 79 | './datasets/cifar100', train=True, download=True, 80 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS) 81 | ), 82 | 'svhn': lambda: datasets.SVHN( 83 | './datasets/svhn', split='train', download=True, 84 | transform=transforms.Compose(_SVHN_TRAIN_TRANSFORMS), 85 | target_transform=transforms.Compose(_SVHN_TARGET_TRANSFORMS), 86 | ), 87 | } 88 | 89 | 90 | TEST_DATASETS = { 91 | 'mnist': lambda: datasets.MNIST( 92 | './datasets/mnist', train=False, 93 | transform=transforms.Compose(_MNIST_TEST_TRANSFORMS) 94 | ), 95 | 'mnist-color': lambda: datasets.MNIST( 96 | './datasets/mnist', train=False, download=True, 97 | transform=transforms.Compose(_MNIST_COLORIZED_TEST_TRANSFORMS) 98 | ), 99 | 'cifar10': lambda: datasets.CIFAR10( 100 | './datasets/cifar10', train=False, 101 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS) 102 | ), 103 | 'cifar100': lambda: datasets.CIFAR100( 104 | './datasets/cifar100', train=False, 105 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS) 106 | ), 107 | 'svhn': lambda: datasets.SVHN( 108 | './datasets/svhn', split='test', download=True, 109 | transform=transforms.Compose(_SVHN_TEST_TRANSFORMS), 110 | target_transform=transforms.Compose(_SVHN_TARGET_TRANSFORMS), 111 | ), 112 | } 113 | 114 | 115 | DATASET_CONFIGS = { 116 | 'mnist': {'size': 32, 'channels': 1, 'classes': 10}, 117 | 'mnist-color': {'size': 32, 'channels': 3, 'classes': 10}, 118 | 'cifar10': {'size': 32, 'channels': 
3, 'classes': 10}, 119 | 'cifar100': {'size': 32, 'channels': 3, 'classes': 100}, 120 | 'svhn': {'size': 32, 'channels': 3, 'classes': 10}, 121 | 122 | } 123 | -------------------------------------------------------------------------------- /dgr.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import utils 3 | from tqdm import tqdm 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.utils.data import ConcatDataset 8 | 9 | 10 | # ============ 11 | # Base Classes 12 | # ============ 13 | 14 | class GenerativeMixin(object): 15 | """Mixin which defines a sampling iterface for a generative model.""" 16 | def sample(self, size): 17 | raise NotImplementedError 18 | 19 | 20 | class BatchTrainable(nn.Module, metaclass=abc.ABCMeta): 21 | """ 22 | Abstract base class which defines a generative-replay-based training 23 | interface for a model. 24 | 25 | """ 26 | @abc.abstractmethod 27 | def train_a_batch(self, x, y, x_=None, y_=None, importance_of_new_task=.5): 28 | raise NotImplementedError 29 | 30 | 31 | # ============================== 32 | # Deep Generative Replay Modules 33 | # ============================== 34 | 35 | class Generator(GenerativeMixin, BatchTrainable): 36 | """Abstract generator module of a scholar module""" 37 | 38 | 39 | class Solver(BatchTrainable): 40 | """Abstract solver module of a scholar module""" 41 | def __init__(self): 42 | super().__init__() 43 | self.optimizer = None 44 | self.criterion = None 45 | 46 | @abc.abstractmethod 47 | def forward(self, x): 48 | raise NotImplementedError 49 | 50 | def solve(self, x): 51 | scores = self(x) 52 | _, predictions = torch.max(scores, 1) 53 | return predictions 54 | 55 | def train_a_batch(self, x, y, x_=None, y_=None, importance_of_new_task=.5): 56 | assert x_ is None or x.size() == x_.size() 57 | assert y_ is None or y.size() == y_.size() 58 | 59 | # clear gradients. 60 | batch_size = x.size(0) 61 | self.optimizer.zero_grad() 62 | 63 | # run the model on the real data. 64 | real_scores = self.forward(x) 65 | real_loss = self.criterion(real_scores, y) 66 | _, real_predicted = real_scores.max(1) 67 | real_prec = (y == real_predicted).sum().data[0] / batch_size 68 | 69 | # run the model on the replayed data. 70 | if x_ is not None and y_ is not None: 71 | replay_scores = self.forward(x_) 72 | replay_loss = self.criterion(replay_scores, y_) 73 | _, replay_predicted = replay_scores.max(1) 74 | replay_prec = (y_ == replay_predicted).sum().data[0] / batch_size 75 | 76 | # calculate joint loss of real data and replayed data. 
77 | loss = ( 78 | importance_of_new_task * real_loss + 79 | (1-importance_of_new_task) * replay_loss 80 | ) 81 | precision = (real_prec + replay_prec) / 2 82 | else: 83 | loss = real_loss 84 | precision = real_prec 85 | 86 | loss.backward() 87 | self.optimizer.step() 88 | return {'loss': loss.data[0], 'precision': precision} 89 | 90 | def set_optimizer(self, optimizer): 91 | self.optimizer = optimizer 92 | 93 | def set_criterion(self, criterion): 94 | self.criterion = criterion 95 | 96 | 97 | class Scholar(GenerativeMixin, nn.Module): 98 | """Scholar for Deep Generative Replay""" 99 | def __init__(self, label, generator, solver): 100 | super().__init__() 101 | self.label = label 102 | self.generator = generator 103 | self.solver = solver 104 | 105 | def train_with_replay( 106 | self, dataset, scholar=None, previous_datasets=None, 107 | importance_of_new_task=.5, batch_size=32, 108 | generator_iterations=2000, 109 | generator_training_callbacks=None, 110 | solver_iterations=1000, 111 | solver_training_callbacks=None, 112 | collate_fn=None): 113 | # scholar and previous datasets cannot be given at the same time. 114 | mutex_condition_infringed = all([ 115 | scholar is not None, 116 | bool(previous_datasets) 117 | ]) 118 | assert not mutex_condition_infringed, ( 119 | 'scholar and previous datasets cannot be given at the same time' 120 | ) 121 | 122 | # train the generator of the scholar. 123 | self._train_batch_trainable_with_replay( 124 | self.generator, dataset, scholar, 125 | previous_datasets=previous_datasets, 126 | importance_of_new_task=importance_of_new_task, 127 | batch_size=batch_size, 128 | iterations=generator_iterations, 129 | training_callbacks=generator_training_callbacks, 130 | collate_fn=collate_fn, 131 | ) 132 | 133 | # train the solver of the scholar. 134 | self._train_batch_trainable_with_replay( 135 | self.solver, dataset, scholar, 136 | previous_datasets=previous_datasets, 137 | importance_of_new_task=importance_of_new_task, 138 | batch_size=batch_size, 139 | iterations=solver_iterations, 140 | training_callbacks=solver_training_callbacks, 141 | collate_fn=collate_fn, 142 | ) 143 | 144 | @property 145 | def name(self): 146 | return self.label 147 | 148 | def sample(self, size): 149 | x = self.generator.sample(size) 150 | y = self.solver.solve(x) 151 | return x.data, y.data 152 | 153 | def _train_batch_trainable_with_replay( 154 | self, trainable, dataset, scholar=None, previous_datasets=None, 155 | importance_of_new_task=.5, batch_size=32, iterations=1000, 156 | training_callbacks=None, collate_fn=None): 157 | # do not train the model when given non-positive iterations. 158 | if iterations <= 0: 159 | return 160 | 161 | # create data loaders. 162 | data_loader = iter(utils.get_data_loader( 163 | dataset, batch_size, cuda=self._is_on_cuda(), 164 | collate_fn=collate_fn, 165 | )) 166 | data_loader_previous = iter(utils.get_data_loader( 167 | ConcatDataset(previous_datasets), batch_size, 168 | cuda=self._is_on_cuda(), collate_fn=collate_fn, 169 | )) if previous_datasets else None 170 | 171 | # define a tqdm progress bar. 172 | progress = tqdm(range(1, iterations+1)) 173 | 174 | for batch_index in progress: 175 | # decide from where to sample the training data. 176 | from_scholar = scholar is not None 177 | from_previous_datasets = bool(previous_datasets) 178 | cuda = self._is_on_cuda() 179 | 180 | # sample the real training data. 
181 | x, y = next(data_loader) 182 | x = Variable(x).cuda() if cuda else Variable(x) 183 | y = Variable(y).cuda() if cuda else Variable(y) 184 | 185 | # sample the replayed training data. 186 | if from_previous_datasets: 187 | x_, y_ = next(data_loader_previous) 188 | elif from_scholar: 189 | x_, y_ = scholar.sample(batch_size) 190 | else: 191 | x_ = y_ = None 192 | 193 | if x_ is not None and y_ is not None: 194 | x_ = Variable(x_).cuda() if cuda else Variable(x_) 195 | y_ = Variable(y_).cuda() if cuda else Variable(y_) 196 | 197 | # train the model with a batch. 198 | result = trainable.train_a_batch( 199 | x, y, x_=x_, y_=y_, 200 | importance_of_new_task=importance_of_new_task 201 | ) 202 | 203 | # fire the callbacks on each iteration. 204 | for callback in (training_callbacks or []): 205 | callback(trainable, progress, batch_index, result) 206 | 207 | def _is_on_cuda(self): 208 | return next(self.parameters()).is_cuda 209 | -------------------------------------------------------------------------------- /gan.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | 5 | class Critic(nn.Module): 6 | def __init__(self, image_size, image_channel_size, channel_size): 7 | # configurations 8 | super().__init__() 9 | self.image_size = image_size 10 | self.image_channel_size = image_channel_size 11 | self.channel_size = channel_size 12 | 13 | # layers 14 | self.conv1 = nn.Conv2d( 15 | image_channel_size, channel_size, 16 | kernel_size=4, stride=2, padding=1 17 | ) 18 | self.conv2 = nn.Conv2d( 19 | channel_size, channel_size*2, 20 | kernel_size=4, stride=2, padding=1 21 | ) 22 | self.conv3 = nn.Conv2d( 23 | channel_size*2, channel_size*4, 24 | kernel_size=4, stride=2, padding=1 25 | ) 26 | self.conv4 = nn.Conv2d( 27 | channel_size*4, channel_size*8, 28 | kernel_size=4, stride=1, padding=1, 29 | ) 30 | self.fc = nn.Linear((image_size//8)**2 * channel_size*4, 1) 31 | 32 | def forward(self, x): 33 | x = F.leaky_relu(self.conv1(x)) 34 | x = F.leaky_relu(self.conv2(x)) 35 | x = F.leaky_relu(self.conv3(x)) 36 | x = F.leaky_relu(self.conv4(x)) 37 | x = x.view(-1, (self.image_size//8)**2 * self.channel_size*4) 38 | return self.fc(x) 39 | 40 | 41 | class Generator(nn.Module): 42 | def __init__(self, z_size, image_size, image_channel_size, channel_size): 43 | # configurations 44 | super().__init__() 45 | self.z_size = z_size 46 | self.image_size = image_size 47 | self.image_channel_size = image_channel_size 48 | self.channel_size = channel_size 49 | 50 | # layers 51 | self.fc = nn.Linear(z_size, (image_size//8)**2 * channel_size*8) 52 | self.bn0 = nn.BatchNorm2d(channel_size*8) 53 | self.bn1 = nn.BatchNorm2d(channel_size*4) 54 | self.deconv1 = nn.ConvTranspose2d( 55 | channel_size*8, channel_size*4, 56 | kernel_size=4, stride=2, padding=1 57 | ) 58 | self.bn2 = nn.BatchNorm2d(channel_size*2) 59 | self.deconv2 = nn.ConvTranspose2d( 60 | channel_size*4, channel_size*2, 61 | kernel_size=4, stride=2, padding=1, 62 | ) 63 | self.bn3 = nn.BatchNorm2d(channel_size) 64 | self.deconv3 = nn.ConvTranspose2d( 65 | channel_size*2, channel_size, 66 | kernel_size=4, stride=2, padding=1 67 | ) 68 | self.deconv4 = nn.ConvTranspose2d( 69 | channel_size, image_channel_size, 70 | kernel_size=3, stride=1, padding=1 71 | ) 72 | 73 | def forward(self, z): 74 | g = F.relu(self.bn0(self.fc(z).view( 75 | z.size(0), 76 | self.channel_size*8, 77 | self.image_size//8, 78 | self.image_size//8, 79 | ))) 80 | g = 
F.relu(self.bn1(self.deconv1(g))) 81 | g = F.relu(self.bn2(self.deconv2(g))) 82 | g = F.relu(self.bn3(self.deconv3(g))) 83 | g = self.deconv4(g) 84 | return F.sigmoid(g) 85 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os.path 4 | import numpy as np 5 | import torch 6 | import utils 7 | from data import get_dataset, DATASET_CONFIGS 8 | from train import train 9 | from dgr import Scholar 10 | from models import WGAN, CNN 11 | 12 | 13 | parser = argparse.ArgumentParser( 14 | 'PyTorch implementation of Deep Generative Replay' 15 | ) 16 | 17 | parser.add_argument( 18 | '--experiment', type=str, 19 | choices=['permutated-mnist', 'svhn-mnist', 'mnist-svhn'], 20 | default='permutated-mnist' 21 | ) 22 | parser.add_argument('--mnist-permutation-number', type=int, default=5) 23 | parser.add_argument('--mnist-permutation-seed', type=int, default=0) 24 | parser.add_argument( 25 | '--replay-mode', type=str, default='generative-replay', 26 | choices=['exact-replay', 'generative-replay', 'none'], 27 | ) 28 | 29 | parser.add_argument('--generator-lambda', type=float, default=10.) 30 | parser.add_argument('--generator-z-size', type=int, default=100) 31 | parser.add_argument('--generator-c-channel-size', type=int, default=64) 32 | parser.add_argument('--generator-g-channel-size', type=int, default=64) 33 | parser.add_argument('--solver-depth', type=int, default=5) 34 | parser.add_argument('--solver-reducing-layers', type=int, default=3) 35 | parser.add_argument('--solver-channel-size', type=int, default=1024) 36 | 37 | parser.add_argument('--generator-c-updates-per-g-update', type=int, default=5) 38 | parser.add_argument('--generator-iterations', type=int, default=3000) 39 | parser.add_argument('--solver-iterations', type=int, default=1000) 40 | parser.add_argument('--importance-of-new-task', type=float, default=.3) 41 | parser.add_argument('--lr', type=float, default=1e-04) 42 | parser.add_argument('--beta1', type=float, default=0.5) 43 | parser.add_argument('--beta2', type=float, default=0.9) 44 | parser.add_argument('--weight-decay', type=float, default=1e-05) 45 | parser.add_argument('--batch-size', type=int, default=32) 46 | parser.add_argument('--test-size', type=int, default=1024) 47 | parser.add_argument('--sample-size', type=int, default=36) 48 | 49 | parser.add_argument('--sample-log', action='store_true') 50 | parser.add_argument('--sample-log-interval', type=int, default=300) 51 | parser.add_argument('--image-log-interval', type=int, default=100) 52 | parser.add_argument('--eval-log-interval', type=int, default=50) 53 | parser.add_argument('--loss-log-interval', type=int, default=30) 54 | parser.add_argument('--checkpoint-dir', type=str, default='./checkpoints') 55 | parser.add_argument('--sample-dir', type=str, default='./samples') 56 | parser.add_argument('--no-gpus', action='store_false', dest='cuda') 57 | 58 | main_command = parser.add_mutually_exclusive_group(required=True) 59 | main_command.add_argument('--train', action='store_true') 60 | main_command.add_argument('--test', action='store_false', dest='train') 61 | 62 | 63 | if __name__ == '__main__': 64 | args = parser.parse_args() 65 | 66 | # decide whether to use cuda or not. 
67 | cuda = torch.cuda.is_available() and args.cuda 68 | experiment = args.experiment 69 | capacity = args.batch_size * max( 70 | args.generator_iterations, 71 | args.solver_iterations 72 | ) 73 | 74 | if experiment == 'permutated-mnist': 75 | # generate permutations for the mnist classification tasks. 76 | np.random.seed(args.mnist_permutation_seed) 77 | permutations = [ 78 | np.random.permutation(DATASET_CONFIGS['mnist']['size']**2) for 79 | _ in range(args.mnist_permutation_number) 80 | ] 81 | 82 | # prepare the datasets. 83 | train_datasets = [ 84 | get_dataset('mnist', permutation=p, capacity=capacity) 85 | for p in permutations 86 | ] 87 | test_datasets = [ 88 | get_dataset('mnist', train=False, permutation=p, capacity=capacity) 89 | for p in permutations 90 | ] 91 | 92 | # decide what configuration to use. 93 | dataset_config = DATASET_CONFIGS['mnist'] 94 | 95 | elif experiment in ('svhn-mnist', 'mnist-svhn'): 96 | mnist_color_train = get_dataset( 97 | 'mnist-color', train=True, capacity=capacity 98 | ) 99 | mnist_color_test = get_dataset( 100 | 'mnist-color', train=False, capacity=capacity 101 | ) 102 | svhn_train = get_dataset('svhn', train=True, capacity=capacity) 103 | svhn_test = get_dataset('svhn', train=False, capacity=capacity) 104 | 105 | # prepare the datasets. 106 | train_datasets = ( 107 | [mnist_color_train, svhn_train] if experiment == 'mnist-svhn' else 108 | [svhn_train, mnist_color_train] 109 | ) 110 | test_datasets = ( 111 | [mnist_color_test, svhn_test] if experiment == 'mnist-svhn' else 112 | [svhn_test, mnist_color_test] 113 | ) 114 | 115 | # decide what configuration to use. 116 | dataset_config = DATASET_CONFIGS['mnist-color'] 117 | else: 118 | raise RuntimeError('Given undefined experiment: {}'.format(experiment)) 119 | 120 | # define the models. 121 | cnn = CNN( 122 | image_size=dataset_config['size'], 123 | image_channel_size=dataset_config['channels'], 124 | classes=dataset_config['classes'], 125 | depth=args.solver_depth, 126 | channel_size=args.solver_channel_size, 127 | reducing_layers=args.solver_reducing_layers, 128 | ) 129 | wgan = WGAN( 130 | z_size=args.generator_z_size, 131 | image_size=dataset_config['size'], 132 | image_channel_size=dataset_config['channels'], 133 | c_channel_size=args.generator_c_channel_size, 134 | g_channel_size=args.generator_g_channel_size, 135 | ) 136 | label = '{experiment}-{replay_mode}-r{importance_of_new_task}'.format( 137 | experiment=experiment, 138 | replay_mode=args.replay_mode, 139 | importance_of_new_task=( 140 | 1 if args.replay_mode == 'none' else 141 | args.importance_of_new_task 142 | ), 143 | ) 144 | scholar = Scholar(label, generator=wgan, solver=cnn) 145 | 146 | # initialize the model. 147 | utils.gaussian_intiailize(scholar, std=.02) 148 | 149 | # use cuda if needed 150 | if cuda: 151 | scholar.cuda() 152 | 153 | # determine whether we need to train the generator or not. 154 | train_generator = ( 155 | args.replay_mode == 'generative-replay' or 156 | args.sample_log 157 | ) 158 | 159 | # run the experiment. 
160 | if args.train: 161 | train( 162 | scholar, train_datasets, test_datasets, 163 | replay_mode=args.replay_mode, 164 | generator_lambda=args.generator_lambda, 165 | generator_iterations=( 166 | args.generator_iterations if train_generator else 0 167 | ), 168 | generator_c_updates_per_g_update=( 169 | args.generator_c_updates_per_g_update 170 | ), 171 | solver_iterations=args.solver_iterations, 172 | importance_of_new_task=args.importance_of_new_task, 173 | batch_size=args.batch_size, 174 | test_size=args.test_size, 175 | sample_size=args.sample_size, 176 | lr=args.lr, weight_decay=args.weight_decay, 177 | beta1=args.beta1, beta2=args.beta2, 178 | loss_log_interval=args.loss_log_interval, 179 | eval_log_interval=args.eval_log_interval, 180 | image_log_interval=args.image_log_interval, 181 | sample_log_interval=args.sample_log_interval, 182 | sample_log=args.sample_log, 183 | sample_dir=args.sample_dir, 184 | checkpoint_dir=args.checkpoint_dir, 185 | collate_fn=utils.label_squeezing_collate_fn, 186 | cuda=cuda 187 | ) 188 | else: 189 | path = os.path.join(args.sample_dir, '{}-sample'.format(scholar.name)) 190 | utils.load_checkpoint(scholar, args.checkpoint_dir) 191 | utils.test_model(scholar.generator, args.sample_size, path) 192 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import torch 3 | from torch import nn, autograd 4 | from torch.autograd import Variable 5 | import gan 6 | import dgr 7 | import utils 8 | from const import EPSILON 9 | 10 | 11 | class WGAN(dgr.Generator): 12 | def __init__(self, z_size, 13 | image_size, image_channel_size, 14 | c_channel_size, g_channel_size): 15 | # configurations 16 | super().__init__() 17 | self.z_size = z_size 18 | self.image_size = image_size 19 | self.image_channel_size = image_channel_size 20 | self.c_channel_size = c_channel_size 21 | self.g_channel_size = g_channel_size 22 | 23 | # components 24 | self.critic = gan.Critic( 25 | image_size=self.image_size, 26 | image_channel_size=self.image_channel_size, 27 | channel_size=self.c_channel_size, 28 | ) 29 | self.generator = gan.Generator( 30 | z_size=self.z_size, 31 | image_size=self.image_size, 32 | image_channel_size=self.image_channel_size, 33 | channel_size=self.g_channel_size, 34 | ) 35 | 36 | # training related components that should be set before training. 37 | self.generator_optimizer = None 38 | self.critic_optimizer = None 39 | self.critic_updates_per_generator_update = None 40 | self.lamda = None 41 | 42 | def train_a_batch(self, x, y, x_=None, y_=None, importance_of_new_task=.5): 43 | assert x_ is None or x.size() == x_.size() 44 | assert y_ is None or y.size() == y_.size() 45 | 46 | # run the critic and backpropagate the errors. 47 | for _ in range(self.critic_updates_per_generator_update): 48 | self.critic_optimizer.zero_grad() 49 | z = self._noise(x.size(0)) 50 | 51 | # run the critic on the real data. 52 | c_loss_real, g_real = self._c_loss(x, z, return_g=True) 53 | c_loss_real_gp = ( 54 | c_loss_real + self._gradient_penalty(x, g_real, self.lamda) 55 | ) 56 | 57 | # run the critic on the replayed data. 
58 | if x_ is not None and y_ is not None: 59 | c_loss_replay, g_replay = self._c_loss(x_, z, return_g=True) 60 | c_loss_replay_gp = (c_loss_replay + self._gradient_penalty( 61 | x_, g_replay, self.lamda 62 | )) 63 | c_loss = ( 64 | importance_of_new_task * c_loss_real + 65 | (1-importance_of_new_task) * c_loss_replay 66 | ) 67 | c_loss_gp = ( 68 | importance_of_new_task * c_loss_real_gp + 69 | (1-importance_of_new_task) * c_loss_replay_gp 70 | ) 71 | else: 72 | c_loss = c_loss_real 73 | c_loss_gp = c_loss_real_gp 74 | 75 | c_loss_gp.backward() 76 | self.critic_optimizer.step() 77 | 78 | # run the generator and backpropagate the errors. 79 | self.generator_optimizer.zero_grad() 80 | z = self._noise(x.size(0)) 81 | g_loss = self._g_loss(z) 82 | g_loss.backward() 83 | self.generator_optimizer.step() 84 | 85 | return {'c_loss': c_loss.data[0], 'g_loss': g_loss.data[0]} 86 | 87 | def sample(self, size): 88 | return self.generator(self._noise(size)) 89 | 90 | def set_generator_optimizer(self, optimizer): 91 | self.generator_optimizer = optimizer 92 | 93 | def set_critic_optimizer(self, optimizer): 94 | self.critic_optimizer = optimizer 95 | 96 | def set_critic_updates_per_generator_update(self, k): 97 | self.critic_updates_per_generator_update = k 98 | 99 | def set_lambda(self, l): 100 | self.lamda = l 101 | 102 | def _noise(self, size): 103 | z = Variable(torch.randn(size, self.z_size)) * .1 104 | return z.cuda() if self._is_on_cuda() else z 105 | 106 | def _c_loss(self, x, z, return_g=False): 107 | g = self.generator(z) 108 | c_x = self.critic(x).mean() 109 | c_g = self.critic(g).mean() 110 | l = -(c_x-c_g) 111 | return (l, g) if return_g else l 112 | 113 | def _g_loss(self, z, return_g=False): 114 | g = self.generator(z) 115 | l = -self.critic(g).mean() 116 | return (l, g) if return_g else l 117 | 118 | def _gradient_penalty(self, x, g, lamda): 119 | assert x.size() == g.size() 120 | a = torch.rand(x.size(0), 1) 121 | a = a.cuda() if self._is_on_cuda() else a 122 | a = a\ 123 | .expand(x.size(0), x.nelement()//x.size(0))\ 124 | .contiguous()\ 125 | .view( 126 | x.size(0), 127 | self.image_channel_size, 128 | self.image_size, 129 | self.image_size 130 | ) 131 | interpolated = Variable(a*x.data + (1-a)*g.data, requires_grad=True) 132 | c = self.critic(interpolated) 133 | gradients = autograd.grad( 134 | c, interpolated, grad_outputs=( 135 | torch.ones(c.size()).cuda() if self._is_on_cuda() else 136 | torch.ones(c.size()) 137 | ), 138 | create_graph=True, 139 | retain_graph=True, 140 | )[0] 141 | return lamda * ((1-(gradients+EPSILON).norm(2, dim=1))**2).mean() 142 | 143 | def _is_on_cuda(self): 144 | return next(self.parameters()).is_cuda 145 | 146 | 147 | class CNN(dgr.Solver): 148 | def __init__(self, 149 | image_size, 150 | image_channel_size, classes, 151 | depth, channel_size, reducing_layers=3): 152 | # configurations 153 | super().__init__() 154 | self.image_size = image_size 155 | self.image_channel_size = image_channel_size 156 | self.classes = classes 157 | self.depth = depth 158 | self.channel_size = channel_size 159 | self.reducing_layers = reducing_layers 160 | 161 | # layers 162 | self.layers = nn.ModuleList([nn.Conv2d( 163 | self.image_channel_size, self.channel_size//(2**(depth-2)), 164 | 3, 1, 1 165 | )]) 166 | 167 | for i in range(self.depth-2): 168 | previous_conv = [ 169 | l for l in self.layers if 170 | isinstance(l, nn.Conv2d) 171 | ][-1] 172 | self.layers.append(nn.Conv2d( 173 | previous_conv.out_channels, 174 | previous_conv.out_channels * 2, 175 | 3, 1 if i >= 
reducing_layers else 2, 1 176 | )) 177 | self.layers.append(nn.BatchNorm2d(previous_conv.out_channels * 2)) 178 | self.layers.append(nn.ReLU()) 179 | 180 | self.layers.append(utils.LambdaModule(lambda x: x.view(x.size(0), -1))) 181 | self.layers.append(nn.Linear( 182 | (image_size//(2**reducing_layers))**2 * channel_size, 183 | self.classes 184 | )) 185 | 186 | def forward(self, x): 187 | return reduce(lambda x, l: l(x), self.layers, x) 188 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl 3 | torchvision==0.1.9 4 | torchtext==0.1.1 5 | visdom 6 | 7 | # data 8 | scipy 9 | scikit-learn 10 | numpy 11 | pillow 12 | 13 | # utils (others) 14 | colorama 15 | tqdm 16 | lmdb 17 | requests 18 | fake-useragent 19 | -------------------------------------------------------------------------------- /run_full_experiments: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ================ 4 | # Permutated MNIST 5 | # ================ 6 | 7 | PERM_MNIST_GENERATOR_ITERATIONS=3000 8 | PERM_MNIST_SOLVER_ITERATIONS=1000 9 | PERM_MNIST_LR=0.0001 10 | PERM_MNIST_IMPORTANCE_OF_NEW_TASK=0.3 11 | 12 | ./main.py --train \ 13 | --experiment=permutated-mnist \ 14 | --replay-mode=exact-replay \ 15 | --solver-iterations=$PERM_MNIST_SOLVER_ITERATIONS \ 16 | --lr=$PERM_MNIST_LR \ 17 | --importance-of-new-task=$PERM_MNIST_IMPORTANCE_OF_NEW_TASK 18 | 19 | ./main.py --train \ 20 | --experiment=permutated-mnist \ 21 | --replay-mode=generative-replay \ 22 | --generator-iterations=$PERM_MNIST_GENERATOR_ITERATIONS \ 23 | --solver-iterations=$PERM_MNIST_SOLVER_ITERATIONS \ 24 | --lr=$PERM_MNIST_LR \ 25 | --importance-of-new-task=$PERM_MNIST_IMPORTANCE_OF_NEW_TASK 26 | 27 | ./main.py --train \ 28 | --experiment=permutated-mnist \ 29 | --replay-mode=none \ 30 | --solver-iterations=$PERM_MNIST_SOLVER_ITERATIONS \ 31 | --lr=$PERM_MNIST_LR 32 | 33 | # ========== 34 | # MNIST-SVHN 35 | # ========== 36 | 37 | MNIST_SVHN_GENERATOR_ITERATIONS=20000 38 | MNIST_SVHN_SOLVER_ITERATIONS=4000 39 | MNIST_SVHN_LR=0.00003 40 | MNIST_SVHN_IMPORTANCE_OF_NEW_TASK=0.4 41 | 42 | ./main.py --train \ 43 | --experiment=mnist-svhn \ 44 | --replay-mode=exact-replay \ 45 | --solver-iterations=$MNIST_SVHN_SOLVER_ITERATIONS \ 46 | --importance-of-new-task=$MNIST_SVHN_IMPORTANCE_OF_NEW_TASK \ 47 | --lr=$MNIST_SVHN_LR 48 | 49 | ./main.py --train \ 50 | --experiment=mnist-svhn \ 51 | --replay-mode=generative-replay \ 52 | --generator-iterations=$MNIST_SVHN_GENERATOR_ITERATIONS \ 53 | --solver-iterations=$MNIST_SVHN_SOLVER_ITERATIONS \ 54 | --importance-of-new-task=$MNIST_SVHN_IMPORTANCE_OF_NEW_TASK \ 55 | --lr=$MNIST_SVHN_LR \ 56 | --sample-log 57 | 58 | ./main.py --train \ 59 | --experiment=mnist-svhn \ 60 | --replay-mode=none \ 61 | --generator-iterations=$MNIST_SVHN_GENERATOR_ITERATIONS \ 62 | --solver-iterations=$MNIST_SVHN_SOLVER_ITERATIONS \ 63 | --lr=$MNIST_SVHN_LR \ 64 | --sample-log 65 | 66 | # ========== 67 | # SVHN-MNIST 68 | # ========== 69 | 70 | SVHN_MNIST_GENERATOR_ITERATIONS=20000 71 | SVHN_MNIST_SOLVER_ITERATIONS=4000 72 | SVHN_MNIST_LR=0.00003 73 | SVHN_MNIST_IMPORTANCE_OF_NEW_TASK=0.4 74 | 75 | ./main.py --train \ 76 | --experiment=svhn-mnist \ 77 | --replay-mode=exact-replay \ 78 | --solver-iterations=$SVHN_MNIST_SOLVER_ITERATIONS \ 79 | 
--importance-of-new-task=$SVHN_MNIST_IMPORTANCE_OF_NEW_TASK \ 80 | --lr=$SVHN_MNIST_LR 81 | 82 | ./main.py --train \ 83 | --experiment=svhn-mnist \ 84 | --replay-mode=generative-replay \ 85 | --generator-iterations=$SVHN_MNIST_GENERATOR_ITERATIONS \ 86 | --solver-iterations=$SVHN_MNIST_SOLVER_ITERATIONS \ 87 | --importance-of-new-task=$SVHN_MNIST_IMPORTANCE_OF_NEW_TASK \ 88 | --lr=$SVHN_MNIST_LR \ 89 | --sample-log 90 | 91 | ./main.py --train \ 92 | --experiment=svhn-mnist \ 93 | --replay-mode=none \ 94 | --generator-iterations=$SVHN_MNIST_GENERATOR_ITERATIONS \ 95 | --solver-iterations=$SVHN_MNIST_SOLVER_ITERATIONS \ 96 | --lr=$SVHN_MNIST_LR \ 97 | --sample-log 98 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import copy 3 | from torch import optim 4 | from torch import nn 5 | import utils 6 | import visual 7 | 8 | 9 | def train(scholar, train_datasets, test_datasets, replay_mode, 10 | generator_lambda=10., 11 | generator_c_updates_per_g_update=5, 12 | generator_iterations=2000, 13 | solver_iterations=1000, 14 | importance_of_new_task=.5, 15 | batch_size=32, 16 | test_size=1024, 17 | sample_size=36, 18 | lr=1e-03, weight_decay=1e-05, 19 | beta1=.5, beta2=.9, 20 | loss_log_interval=30, 21 | eval_log_interval=50, 22 | image_log_interval=100, 23 | sample_log_interval=300, 24 | sample_log=False, 25 | sample_dir='./samples', 26 | checkpoint_dir='./checkpoints', 27 | collate_fn=None, 28 | cuda=False): 29 | # define solver criterion and generators for the scholar model. 30 | solver_criterion = nn.CrossEntropyLoss() 31 | solver_optimizer = optim.Adam( 32 | scholar.solver.parameters(), 33 | lr=lr, weight_decay=weight_decay, betas=(beta1, beta2), 34 | ) 35 | generator_g_optimizer = optim.Adam( 36 | scholar.generator.generator.parameters(), 37 | lr=lr, weight_decay=weight_decay, betas=(beta1, beta2), 38 | ) 39 | generator_c_optimizer = optim.Adam( 40 | scholar.generator.critic.parameters(), 41 | lr=lr, weight_decay=weight_decay, betas=(beta1, beta2), 42 | ) 43 | 44 | # set the criterion, optimizers, and training configurations for the 45 | # scholar model. 46 | scholar.solver.set_criterion(solver_criterion) 47 | scholar.solver.set_optimizer(solver_optimizer) 48 | scholar.generator.set_lambda(generator_lambda) 49 | scholar.generator.set_generator_optimizer(generator_g_optimizer) 50 | scholar.generator.set_critic_optimizer(generator_c_optimizer) 51 | scholar.generator.set_critic_updates_per_generator_update( 52 | generator_c_updates_per_g_update 53 | ) 54 | scholar.train() 55 | 56 | # define the previous scholar who will generate samples of previous tasks. 57 | previous_scholar = None 58 | previous_datasets = None 59 | 60 | for task, train_dataset in enumerate(train_datasets, 1): 61 | # define callbacks for visualizing the training process. 
62 | generator_training_callbacks = [_generator_training_callback( 63 | loss_log_interval=loss_log_interval, 64 | image_log_interval=image_log_interval, 65 | sample_log_interval=sample_log_interval, 66 | sample_log=sample_log, 67 | sample_dir=sample_dir, 68 | sample_size=sample_size, 69 | current_task=task, 70 | total_tasks=len(train_datasets), 71 | total_iterations=generator_iterations, 72 | batch_size=batch_size, 73 | replay_mode=replay_mode, 74 | env=scholar.name, 75 | )] 76 | solver_training_callbacks = [_solver_training_callback( 77 | loss_log_interval=loss_log_interval, 78 | eval_log_interval=eval_log_interval, 79 | current_task=task, 80 | total_tasks=len(train_datasets), 81 | total_iterations=solver_iterations, 82 | batch_size=batch_size, 83 | test_size=test_size, 84 | test_datasets=test_datasets, 85 | replay_mode=replay_mode, 86 | cuda=cuda, 87 | collate_fn=collate_fn, 88 | env=scholar.name, 89 | )] 90 | 91 | # train the scholar with generative replay. 92 | scholar.train_with_replay( 93 | train_dataset, 94 | scholar=previous_scholar, 95 | previous_datasets=previous_datasets, 96 | importance_of_new_task=importance_of_new_task, 97 | batch_size=batch_size, 98 | generator_iterations=generator_iterations, 99 | generator_training_callbacks=generator_training_callbacks, 100 | solver_iterations=solver_iterations, 101 | solver_training_callbacks=solver_training_callbacks, 102 | collate_fn=collate_fn, 103 | ) 104 | 105 | previous_scholar = ( 106 | copy.deepcopy(scholar) if replay_mode == 'generative-replay' else 107 | None 108 | ) 109 | previous_datasets = ( 110 | train_datasets[:task] if replay_mode == 'exact-replay' else 111 | None 112 | ) 113 | 114 | # save the model after the experiment. 115 | print() 116 | utils.save_checkpoint(scholar, checkpoint_dir) 117 | print() 118 | print() 119 | 120 | 121 | def _generator_training_callback( 122 | loss_log_interval, 123 | image_log_interval, 124 | sample_log_interval, 125 | sample_log, 126 | sample_dir, 127 | current_task, 128 | total_tasks, 129 | total_iterations, 130 | batch_size, 131 | sample_size, 132 | replay_mode, 133 | env): 134 | 135 | def cb(generator, progress, batch_index, result): 136 | iteration = (current_task-1)*total_iterations + batch_index 137 | progress.set_description(( 138 | ' ' 139 | 'task: {task}/{tasks} | ' 140 | 'progress: [{trained}/{total}] ({percentage:.0f}%) | ' 141 | 'loss => ' 142 | 'g: {g_loss:.4} / ' 143 | 'w: {w_dist:.4}' 144 | ).format( 145 | task=current_task, 146 | tasks=total_tasks, 147 | trained=batch_size * batch_index, 148 | total=batch_size * total_iterations, 149 | percentage=(100.*batch_index/total_iterations), 150 | g_loss=result['g_loss'], 151 | w_dist=-result['c_loss'], 152 | )) 153 | 154 | # log the losses of the generator. 155 | if iteration % loss_log_interval == 0: 156 | visual.visualize_scalar( 157 | result['g_loss'], 'generator g loss', iteration, env=env 158 | ) 159 | visual.visualize_scalar( 160 | -result['c_loss'], 'generator w distance', iteration, env=env 161 | ) 162 | 163 | # log the generated images of the generator. 
164 | if iteration % image_log_interval == 0: 165 | visual.visualize_images( 166 | generator.sample(sample_size).data, 167 | 'generated samples ({replay_mode})' 168 | .format(replay_mode=replay_mode), env=env, 169 | ) 170 | 171 | # log the sample images of the generator 172 | if iteration % sample_log_interval == 0 and sample_log: 173 | utils.test_model(generator, sample_size, os.path.join( 174 | sample_dir, 175 | env + '-sample-logs', 176 | str(iteration) 177 | ), verbose=False) 178 | 179 | return cb 180 | 181 | 182 | def _solver_training_callback( 183 | loss_log_interval, 184 | eval_log_interval, 185 | current_task, 186 | total_tasks, 187 | total_iterations, 188 | batch_size, 189 | test_size, 190 | test_datasets, 191 | cuda, 192 | replay_mode, 193 | collate_fn, 194 | env): 195 | 196 | def cb(solver, progress, batch_index, result): 197 | iteration = (current_task-1)*total_iterations + batch_index 198 | progress.set_description(( 199 | ' ' 200 | 'task: {task}/{tasks} | ' 201 | 'progress: [{trained}/{total}] ({percentage:.0f}%) | ' 202 | 'loss: {loss:.4} | ' 203 | 'prec: {prec:.4}' 204 | ).format( 205 | task=current_task, 206 | tasks=total_tasks, 207 | trained=batch_size * batch_index, 208 | total=batch_size * total_iterations, 209 | percentage=(100.*batch_index/total_iterations), 210 | loss=result['loss'], 211 | prec=result['precision'], 212 | )) 213 | 214 | # log the loss of the solver. 215 | if iteration % loss_log_interval == 0: 216 | visual.visualize_scalar( 217 | result['loss'], 'solver loss', iteration, env=env 218 | ) 219 | 220 | # evaluate the solver on multiple tasks. 221 | if iteration % eval_log_interval == 0: 222 | names = ['task {}'.format(i+1) for i in range(len(test_datasets))] 223 | precs = [ 224 | utils.validate( 225 | solver, test_datasets[i], test_size=test_size, 226 | cuda=cuda, verbose=False, collate_fn=collate_fn, 227 | ) if i+1 <= current_task else 0 for i in 228 | range(len(test_datasets)) 229 | ] 230 | title = 'precision ({replay_mode})'.format(replay_mode=replay_mode) 231 | visual.visualize_scalars( 232 | precs, names, title, 233 | iteration, env=env 234 | ) 235 | 236 | return cb 237 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torchvision 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.dataloader import default_collate 9 | 10 | 11 | def label_squeezing_collate_fn(batch): 12 | x, y = default_collate(batch) 13 | return x, y.long().squeeze() 14 | 15 | 16 | def get_data_loader(dataset, batch_size, cuda=False, collate_fn=None): 17 | return DataLoader( 18 | dataset, batch_size=batch_size, shuffle=True, 19 | collate_fn=(collate_fn or default_collate), 20 | **({'num_workers': 0, 'pin_memory': True} if cuda else {}) 21 | ) 22 | 23 | 24 | def save_checkpoint(model, model_dir): 25 | path = os.path.join(model_dir, model.name) 26 | 27 | # save the checkpoint. 28 | if not os.path.exists(model_dir): 29 | os.makedirs(model_dir) 30 | 31 | torch.save({'state': model.state_dict()}, path) 32 | 33 | # notify that we successfully saved the checkpoint. 34 | print('=> saved the model {name} to {path}'.format( 35 | name=model.name, path=path 36 | )) 37 | 38 | 39 | def load_checkpoint(model, model_dir): 40 | path = os.path.join(model_dir, model.name) 41 | 42 | # load the checkpoint. 
43 | checkpoint = torch.load(path) 44 | print('=> loaded checkpoint of {name} from {path}'.format( 45 | name=model.name, path=path 46 | )) 47 | 48 | # load parameters and return the checkpoint's epoch and precision. 49 | model.load_state_dict(checkpoint['state']) 50 | 51 | 52 | def test_model(model, sample_size, path, verbose=True): 53 | os.makedirs(os.path.dirname(path), exist_ok=True) 54 | torchvision.utils.save_image( 55 | model.sample(sample_size).data, 56 | path + '.jpg', 57 | nrow=6, 58 | ) 59 | if verbose: 60 | print('=> generated sample images at "{}".'.format(path)) 61 | 62 | 63 | def validate(model, dataset, test_size=1024, 64 | cuda=False, verbose=True, collate_fn=None): 65 | data_loader = get_data_loader( 66 | dataset, 128, cuda=cuda, 67 | collate_fn=(collate_fn or default_collate), 68 | ) 69 | total_tested = 0 70 | total_correct = 0 71 | for data, labels in data_loader: 72 | # break on test size. 73 | if total_tested >= test_size: 74 | break 75 | # test the model. 76 | data = Variable(data).cuda() if cuda else Variable(data) 77 | labels = Variable(labels).cuda() if cuda else Variable(labels) 78 | scores = model(data) 79 | _, predicted = torch.max(scores, 1) 80 | 81 | # update statistics. 82 | total_correct += (predicted == labels).sum().data[0] 83 | total_tested += len(data) 84 | 85 | precision = total_correct / total_tested 86 | if verbose: 87 | print('=> precision: {:.3f}'.format(precision)) 88 | return precision 89 | 90 | 91 | def xavier_initialize(model): 92 | modules = [m for n, m in model.named_modules() if 'conv' in n or 'fc' in n] 93 | parameters = [p for m in modules for p in m.parameters()] 94 | 95 | for p in parameters: 96 | if p.dim() >= 2: 97 | nn.init.xavier_normal(p) 98 | else: 99 | nn.init.constant(p, 0) 100 | 101 | 102 | def gaussian_intiailize(model, std=.01): 103 | modules = [m for n, m in model.named_modules() if 'conv' in n or 'fc' in n] 104 | parameters = [p for m in modules for p in m.parameters()] 105 | 106 | for p in parameters: 107 | if p.dim() >= 2: 108 | nn.init.normal(p, std=std) 109 | else: 110 | nn.init.constant(p, 0) 111 | 112 | 113 | class LambdaModule(nn.Module): 114 | def __init__(self, f): 115 | super().__init__() 116 | self.f = f 117 | 118 | def forward(self, x): 119 | return self.f(x) 120 | -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.cuda import FloatTensor as CUDATensor 3 | from visdom import Visdom 4 | 5 | _WINDOW_CASH = {} 6 | 7 | 8 | def _vis(env='main'): 9 | return Visdom(env=env) 10 | 11 | 12 | def visualize_image(tensor, name, label=None, env='main', w=250, h=250, 13 | update_window_without_label=False): 14 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 15 | title = name + ('-{}'.format(label) if label is not None else '') 16 | 17 | _WINDOW_CASH[title] = _vis(env).image( 18 | tensor.numpy(), win=_WINDOW_CASH.get(title), 19 | opts=dict(title=title, width=w, height=h) 20 | ) 21 | 22 | # This is useful when you want to maintain the most recent images. 
23 | if update_window_without_label: 24 | _WINDOW_CASH[name] = _vis(env).image( 25 | tensor.numpy(), win=_WINDOW_CASH.get(name), 26 | opts=dict(title=name, width=w, height=h) 27 | ) 28 | 29 | 30 | def visualize_images(tensor, name, label=None, env='main', w=400, h=400, 31 | update_window_without_label=False): 32 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 33 | title = name + ('-{}'.format(label) if label is not None else '') 34 | 35 | _WINDOW_CASH[title] = _vis(env).images( 36 | tensor.numpy(), win=_WINDOW_CASH.get(title), nrow=6, 37 | opts=dict(title=title, width=w, height=h) 38 | ) 39 | 40 | # This is useful when you want to maintain the most recent images. 41 | if update_window_without_label: 42 | _WINDOW_CASH[name] = _vis(env).images( 43 | tensor.numpy(), win=_WINDOW_CASH.get(name), nrow=6, 44 | opts=dict(title=name, width=w, height=h) 45 | ) 46 | 47 | 48 | def visualize_kernel(kernel, name, label=None, env='main', w=250, h=250, 49 | update_window_without_label=False, compress_tensor=False): 50 | # Do not visualize kernels that does not exists. 51 | if kernel is None: 52 | return 53 | 54 | assert len(kernel.size()) in (2, 4) 55 | title = name + ('-{}'.format(label) if label is not None else '') 56 | kernel = kernel.cpu() if isinstance(kernel, CUDATensor) else kernel 57 | kernel_norm = kernel if len(kernel.size()) == 2 else ( 58 | (kernel**2).mean(-1).mean(-1) if compress_tensor else 59 | kernel.view( 60 | kernel.size()[0] * kernel.size()[2], 61 | kernel.size()[1] * kernel.size()[3], 62 | ) 63 | ) 64 | kernel_norm = kernel_norm.abs() 65 | 66 | visualized = ( 67 | (kernel_norm - kernel_norm.min()) / 68 | (kernel_norm.max() - kernel_norm.min()) 69 | ).numpy() 70 | 71 | _WINDOW_CASH[title] = _vis(env).image( 72 | visualized, win=_WINDOW_CASH.get(title), 73 | opts=dict(title=title, width=w, height=h) 74 | ) 75 | 76 | # This is useful when you want to maintain the most recent images. 77 | if update_window_without_label: 78 | _WINDOW_CASH[name] = _vis(env).image( 79 | visualized, win=_WINDOW_CASH.get(name), 80 | opts=dict(title=name, width=w, height=h) 81 | ) 82 | 83 | 84 | def visualize_scalar(scalar, name, iteration, env='main'): 85 | visualize_scalars( 86 | [scalar] if isinstance(scalar, float) or len(scalar) == 1 else scalar, 87 | [name], name, iteration, env=env 88 | ) 89 | 90 | 91 | def visualize_scalars(scalars, names, title, iteration, env='main'): 92 | assert len(scalars) == len(names) 93 | # Convert scalar tensors to numpy arrays. 94 | scalars, names = list(scalars), list(names) 95 | scalars = [s.cpu() if isinstance(s, CUDATensor) else s for s in scalars] 96 | scalars = [s.numpy() if hasattr(s, 'numpy') else np.array([s]) for s in 97 | scalars] 98 | multi = len(scalars) > 1 99 | num = len(scalars) 100 | 101 | options = dict( 102 | fillarea=True, 103 | legend=names, 104 | width=400, 105 | height=400, 106 | xlabel='Iterations', 107 | ylabel=title, 108 | title=title, 109 | marginleft=30, 110 | marginright=30, 111 | marginbottom=80, 112 | margintop=30, 113 | ) 114 | 115 | X = ( 116 | np.column_stack(np.array([iteration] * num)) if multi else 117 | np.array([iteration] * num) 118 | ) 119 | Y = np.column_stack(scalars) if multi else scalars[0] 120 | 121 | if title in _WINDOW_CASH: 122 | _vis(env).updateTrace(X=X, Y=Y, win=_WINDOW_CASH[title], opts=options) 123 | else: 124 | _WINDOW_CASH[title] = _vis(env).line(X=X, Y=Y, opts=options) 125 | --------------------------------------------------------------------------------
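A note on the generator objective implemented in `gan.py` and `models.py`: as read from `_c_loss`, `_g_loss`, and `_gradient_penalty` in `models.WGAN`, the critic and generator updates follow the standard WGAN-GP formulation. In the code's own notation (critic $C$, generator $G$, interpolate $\hat{x} = a\,x + (1-a)\,G(z)$ with $a \sim U(0,1)$, and $\lambda$ set by `--generator-lambda`):

$$
\mathcal{L}_C = \mathbb{E}_z\big[C(G(z))\big] - \mathbb{E}_x\big[C(x)\big]
              + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert\nabla_{\hat{x}} C(\hat{x})\rVert_2 - 1)^2\big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_z\big[C(G(z))\big].
$$

When replay is active, the real and replayed critic terms are mixed with weights `importance_of_new_task` and `1 - importance_of_new_task`, in the same way the solver mixes its real and replayed classification losses.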