├── .gitignore
├── LICENSE
├── README.md
├── arts
│   ├── mnist-svhn-exact-replay.png
│   ├── mnist-svhn-generative-replay-sample.jpg
│   ├── mnist-svhn-generative-replay.gif
│   ├── mnist-svhn-generative-replay.png
│   ├── mnist-svhn-none-sample.jpg
│   ├── mnist-svhn-none.gif
│   ├── mnist-svhn-none.png
│   ├── model.png
│   ├── permutated-mnist-exact-replay.png
│   ├── permutated-mnist-generative-replay.png
│   ├── permutated-mnist-none.png
│   ├── svhn-mnist-exact-replay.png
│   ├── svhn-mnist-generative-replay-r0.4-sample.jpg
│   ├── svhn-mnist-generative-replay.gif
│   ├── svhn-mnist-generative-replay.png
│   ├── svhn-mnist-none-r1-sample.jpg
│   ├── svhn-mnist-none.gif
│   └── svhn-mnist-none.png
├── const.py
├── data.py
├── dgr.py
├── gan.py
├── main.py
├── models.py
├── requirements.txt
├── run_full_experiments
├── train.py
├── utils.py
└── visual.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 |
104 | datasets/
105 | checkpoints/
106 | samples/
107 | .env
108 | _cache_*
109 | .download.cgi*
110 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ha Junsoo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-deep-generative-replay
2 |
3 | PyTorch implementation of [Continual Learning with Deep Generative Replay, NIPS 2017](https://arxiv.org/abs/1705.08690)
4 |
5 | ![Model architecture](./arts/model.png)
6 |
7 |
8 | ## Results
9 |
10 | ### Continual Learning on Permutated MNISTs
11 |
12 | - Test precisions **without replay** (*left*), **with exact replay** (*middle*), and **with Deep Generative Replay** (*right*).
13 |
14 | ![permutated-mnist-none](./arts/permutated-mnist-none.png) ![permutated-mnist-exact-replay](./arts/permutated-mnist-exact-replay.png) ![permutated-mnist-generative-replay](./arts/permutated-mnist-generative-replay.png)
15 |
16 | ### Continual Learning on MNIST-SVHN
17 |
18 | - Test precisions **without replay** (*left*), **with exact replay** (*middle*), and **with Deep Generative Replay** (*right*).
19 |
20 | ![mnist-svhn-none](./arts/mnist-svhn-none.png) ![mnist-svhn-exact-replay](./arts/mnist-svhn-exact-replay.png) ![mnist-svhn-generative-replay](./arts/mnist-svhn-generative-replay.png)
21 |
22 | - Generated samples from the scholar trained **without any replay** (*left*) and **with Deep Generative Replay** (*right*).
23 |
24 | ![mnist-svhn-none-sample](./arts/mnist-svhn-none-sample.jpg) ![mnist-svhn-generative-replay-sample](./arts/mnist-svhn-generative-replay-sample.jpg)
25 |
26 | - Training the scholar's generator **without replay** (*left*) and **with Deep Generative Replay** (*right*).
27 |
28 | ![mnist-svhn-none](./arts/mnist-svhn-none.gif) ![mnist-svhn-generative-replay](./arts/mnist-svhn-generative-replay.gif)
29 |
30 | ### Continual Learning on SVHN-MNIST
31 |
32 | - Test precisions **without replay** (*left*), **with exact replay** (*middle*), and **with Deep Generative Replay** (*right*).
33 |
34 | ![svhn-mnist-none](./arts/svhn-mnist-none.png) ![svhn-mnist-exact-replay](./arts/svhn-mnist-exact-replay.png) ![svhn-mnist-generative-replay](./arts/svhn-mnist-generative-replay.png)
35 |
36 | - Generated samples from the scholar trained **without replay** (*left*) and **with Deep Generative Replay** (*right*).
37 |
38 | ![svhn-mnist-none-r1-sample](./arts/svhn-mnist-none-r1-sample.jpg) ![svhn-mnist-generative-replay-r0.4-sample](./arts/svhn-mnist-generative-replay-r0.4-sample.jpg)
39 |
40 | - Training the scholar's generator **without replay** (*left*) and **with Deep Generative Replay** (*right*).
41 |
42 | ![svhn-mnist-none](./arts/svhn-mnist-none.gif) ![svhn-mnist-generative-replay](./arts/svhn-mnist-generative-replay.gif)
43 |
44 |
45 | ## Installation
46 | ```shell
47 | $ git clone https://github.com/kuc2477/pytorch-deep-generative-replay
48 | $ pip install -r pytorch-deep-generative-replay/requirements.txt
49 | ```
50 |
51 | ## Commands
52 |
53 | ### Usage
54 | ```shell
55 | $ ./main.py --help
56 | usage: PyTorch implementation of Deep Generative Replay [-h]
57 | [--experiment {permutated-mnist,svhn-mnist,mnist-svhn}]
58 | [--mnist-permutation-number MNIST_PERMUTATION_NUMBER]
59 | [--mnist-permutation-seed MNIST_PERMUTATION_SEED]
60 | --replay-mode
61 | {exact-replay,generative-replay,none}
62 | [--generator-z-size GENERATOR_Z_SIZE]
63 | [--generator-c-channel-size GENERATOR_C_CHANNEL_SIZE]
64 | [--generator-g-channel-size GENERATOR_G_CHANNEL_SIZE]
65 | [--solver-depth SOLVER_DEPTH]
66 | [--solver-reducing-layers SOLVER_REDUCING_LAYERS]
67 | [--solver-channel-size SOLVER_CHANNEL_SIZE]
68 | [--generator-c-updates-per-g-update GENERATOR_C_UPDATES_PER_G_UPDATE]
69 | [--generator-iterations GENERATOR_ITERATIONS]
70 | [--solver-iterations SOLVER_ITERATIONS]
71 | [--importance-of-new-task IMPORTANCE_OF_NEW_TASK]
72 | [--lr LR]
73 | [--weight-decay WEIGHT_DECAY]
74 | [--batch-size BATCH_SIZE]
75 | [--test-size TEST_SIZE]
76 | [--sample-size SAMPLE_SIZE]
77 | [--image-log-interval IMAGE_LOG_INTERVAL]
78 | [--eval-log-interval EVAL_LOG_INTERVAL]
79 | [--loss-log-interval LOSS_LOG_INTERVAL]
80 | [--checkpoint-dir CHECKPOINT_DIR]
81 | [--sample-dir SAMPLE_DIR]
82 | [--no-gpus]
83 | (--train | --test)
84 |
85 | ```
86 |
87 | ### To Run Full Experiments
88 | ```shell
89 | # Run a visdom server and conduct full experiments
90 | $ python -m visdom.server &
91 | $ ./run_full_experiments
92 | ```
93 |
94 | ### To Run a Single Experiment
95 | ```shell
96 | # Run a visdom server and conduct a desired experiment
97 | $ python -m visdom.server &
98 | $ ./main.py --train --experiment=[permutated-mnist|svhn-mnist|mnist-svhn] --replay-mode=[exact-replay|generative-replay|none]
99 | ```
100 |
101 | ### To Generate Images from the learned Scholar
102 | ```shell
103 | $ # Run the command below and visit the "samples" directory
104 | $ ./main.py --test --experiment=[permutated-mnist|svhn-mnist|mnist-svhn] --replay-mode=[exact-replay|generative-replay|none]
105 | ```
106 |
107 | ## Note
108 | - I was unable to find the supplementary materials in which, according to the paper, the authors describe their experimental details. I therefore arbitrarily chose a 4-convolutional-layer CNN as the solver in this project. If you know where the supplementary materials can be found, please let me know through the project's GitHub issues.
109 |
110 | ## Reference
111 | - [Continual Learning with Deep Generative Replay, arxiv:1705.08690](https://arxiv.org/abs/1705.08690)
112 |
113 |
114 | ## Author
115 | Ha Junsoo / [@kuc2477](https://github.com/kuc2477) / MIT License
116 |
--------------------------------------------------------------------------------
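For a concrete picture of how the scholar described in the README is assembled from the WGAN generator and the CNN solver, here is a minimal sketch; the class names and hyperparameters mirror the argparse defaults in `main.py`, and the label string `'example-scholar'` is arbitrary.

```python
# A minimal sketch: build a scholar for 3-channel 32x32 inputs
# (hyperparameters mirror the argparse defaults in main.py).
from dgr import Scholar
from models import WGAN, CNN

solver = CNN(
    image_size=32, image_channel_size=3, classes=10,
    depth=5, channel_size=1024, reducing_layers=3,
)
generator = WGAN(
    z_size=100, image_size=32, image_channel_size=3,
    c_channel_size=64, g_channel_size=64,
)
scholar = Scholar('example-scholar', generator=generator, solver=solver)
```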
/arts/mnist-svhn-exact-replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-exact-replay.png
--------------------------------------------------------------------------------
/arts/mnist-svhn-generative-replay-sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-generative-replay-sample.jpg
--------------------------------------------------------------------------------
/arts/mnist-svhn-generative-replay.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-generative-replay.gif
--------------------------------------------------------------------------------
/arts/mnist-svhn-generative-replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-generative-replay.png
--------------------------------------------------------------------------------
/arts/mnist-svhn-none-sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-none-sample.jpg
--------------------------------------------------------------------------------
/arts/mnist-svhn-none.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-none.gif
--------------------------------------------------------------------------------
/arts/mnist-svhn-none.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/mnist-svhn-none.png
--------------------------------------------------------------------------------
/arts/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/model.png
--------------------------------------------------------------------------------
/arts/permutated-mnist-exact-replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/permutated-mnist-exact-replay.png
--------------------------------------------------------------------------------
/arts/permutated-mnist-generative-replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/permutated-mnist-generative-replay.png
--------------------------------------------------------------------------------
/arts/permutated-mnist-none.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/permutated-mnist-none.png
--------------------------------------------------------------------------------
/arts/svhn-mnist-exact-replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-exact-replay.png
--------------------------------------------------------------------------------
/arts/svhn-mnist-generative-replay-r0.4-sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-generative-replay-r0.4-sample.jpg
--------------------------------------------------------------------------------
/arts/svhn-mnist-generative-replay.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-generative-replay.gif
--------------------------------------------------------------------------------
/arts/svhn-mnist-generative-replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-generative-replay.png
--------------------------------------------------------------------------------
/arts/svhn-mnist-none-r1-sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-none-r1-sample.jpg
--------------------------------------------------------------------------------
/arts/svhn-mnist-none.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-none.gif
--------------------------------------------------------------------------------
/arts/svhn-mnist-none.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-deep-generative-replay/b2d82e45b69876ddab0f7ce38737888e17e377f2/arts/svhn-mnist-none.png
--------------------------------------------------------------------------------
/const.py:
--------------------------------------------------------------------------------
1 | EPSILON = 1e-16
2 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | from torchvision import datasets, transforms
4 | from PIL import ImageOps
5 | from torch.utils.data import ConcatDataset
6 |
7 |
8 | def _permutate_image_pixels(image, permutation):
9 | if permutation is None:
10 | return image
11 |
12 | c, h, w = image.size()
13 | image = image.view(-1, c)
14 | image = image[permutation, :]
15 | return image.view(c, h, w)
16 |
17 |
18 | def _colorize_grayscale_image(image):
19 | return ImageOps.colorize(image, (0, 0, 0), (255, 255, 255))
20 |
21 |
22 | def get_dataset(name, train=True, permutation=None, capacity=None):
23 | dataset = (TRAIN_DATASETS[name] if train else TEST_DATASETS[name])()
24 | dataset.transform = transforms.Compose([
25 | dataset.transform,
26 | transforms.Lambda(lambda x: _permutate_image_pixels(x, permutation)),
27 | ])
28 |
29 | if capacity is not None and len(dataset) < capacity:
30 | return ConcatDataset([
31 | copy.deepcopy(dataset) for _ in
32 | range(math.ceil(capacity / len(dataset)))
33 | ])
34 | else:
35 | return dataset
36 |
37 |
38 | _MNIST_TRAIN_TRANSFORMS = _MNIST_TEST_TRANSFORMS = [
39 | transforms.ToTensor(),
40 | transforms.ToPILImage(),
41 | transforms.Pad(2),
42 | transforms.ToTensor(),
43 | ]
44 |
45 | _MNIST_COLORIZED_TRAIN_TRANSFORMS = _MNIST_COLORIZED_TEST_TRANSFORMS = [
46 | transforms.ToTensor(),
47 | transforms.ToPILImage(),
48 | transforms.Lambda(lambda x: _colorize_grayscale_image(x)),
49 | transforms.Pad(2),
50 | transforms.ToTensor(),
51 | ]
52 |
53 | _CIFAR_TRAIN_TRANSFORMS = _CIFAR_TEST_TRANSFORMS = [
54 | transforms.ToTensor(),
55 | ]
56 |
57 | _SVHN_TRAIN_TRANSFORMS = _SVHN_TEST_TRANSFORMS = [
58 | transforms.ToTensor(),
59 | ]
60 | _SVHN_TARGET_TRANSFORMS = [
61 | transforms.Lambda(lambda y: y % 10)
62 | ]
63 |
64 |
65 | TRAIN_DATASETS = {
66 | 'mnist': lambda: datasets.MNIST(
67 | './datasets/mnist', train=True, download=True,
68 | transform=transforms.Compose(_MNIST_TRAIN_TRANSFORMS)
69 | ),
70 | 'mnist-color': lambda: datasets.MNIST(
71 | './datasets/mnist', train=True, download=True,
72 | transform=transforms.Compose(_MNIST_COLORIZED_TRAIN_TRANSFORMS)
73 | ),
74 | 'cifar10': lambda: datasets.CIFAR10(
75 | './datasets/cifar10', train=True, download=True,
76 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS)
77 | ),
78 | 'cifar100': lambda: datasets.CIFAR100(
79 | './datasets/cifar100', train=True, download=True,
80 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS)
81 | ),
82 | 'svhn': lambda: datasets.SVHN(
83 | './datasets/svhn', split='train', download=True,
84 | transform=transforms.Compose(_SVHN_TRAIN_TRANSFORMS),
85 | target_transform=transforms.Compose(_SVHN_TARGET_TRANSFORMS),
86 | ),
87 | }
88 |
89 |
90 | TEST_DATASETS = {
91 | 'mnist': lambda: datasets.MNIST(
92 | './datasets/mnist', train=False,
93 | transform=transforms.Compose(_MNIST_TEST_TRANSFORMS)
94 | ),
95 | 'mnist-color': lambda: datasets.MNIST(
96 | './datasets/mnist', train=False, download=True,
97 | transform=transforms.Compose(_MNIST_COLORIZED_TEST_TRANSFORMS)
98 | ),
99 | 'cifar10': lambda: datasets.CIFAR10(
100 | './datasets/cifar10', train=False,
101 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS)
102 | ),
103 | 'cifar100': lambda: datasets.CIFAR100(
104 | './datasets/cifar100', train=False,
105 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS)
106 | ),
107 | 'svhn': lambda: datasets.SVHN(
108 | './datasets/svhn', split='test', download=True,
109 | transform=transforms.Compose(_SVHN_TEST_TRANSFORMS),
110 | target_transform=transforms.Compose(_SVHN_TARGET_TRANSFORMS),
111 | ),
112 | }
113 |
114 |
115 | DATASET_CONFIGS = {
116 | 'mnist': {'size': 32, 'channels': 1, 'classes': 10},
117 | 'mnist-color': {'size': 32, 'channels': 3, 'classes': 10},
118 | 'cifar10': {'size': 32, 'channels': 3, 'classes': 10},
119 | 'cifar100': {'size': 32, 'channels': 3, 'classes': 100},
120 | 'svhn': {'size': 32, 'channels': 3, 'classes': 10},
121 |
122 | }
123 |
--------------------------------------------------------------------------------
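A short usage sketch for the helpers above, assuming the permuted-MNIST setting; the first call downloads MNIST into `./datasets/mnist`.

```python
# Sketch: build one permuted-MNIST task with the helpers in data.py.
import numpy as np
from data import get_dataset, DATASET_CONFIGS

# a fixed pixel permutation over the padded 32x32 images
permutation = np.random.permutation(DATASET_CONFIGS['mnist']['size'] ** 2)

train_set = get_dataset('mnist', train=True, permutation=permutation)
image, label = train_set[0]  # a 1x32x32 tensor with permuted pixels, and its label
```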
/dgr.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import utils
3 | from tqdm import tqdm
4 | import torch
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.utils.data import ConcatDataset
8 |
9 |
10 | # ============
11 | # Base Classes
12 | # ============
13 |
14 | class GenerativeMixin(object):
15 | """Mixin which defines a sampling iterface for a generative model."""
16 | def sample(self, size):
17 | raise NotImplementedError
18 |
19 |
20 | class BatchTrainable(nn.Module, metaclass=abc.ABCMeta):
21 | """
22 | Abstract base class which defines a generative-replay-based training
23 | interface for a model.
24 |
25 | """
26 | @abc.abstractmethod
27 | def train_a_batch(self, x, y, x_=None, y_=None, importance_of_new_task=.5):
28 | raise NotImplementedError
29 |
30 |
31 | # ==============================
32 | # Deep Generative Replay Modules
33 | # ==============================
34 |
35 | class Generator(GenerativeMixin, BatchTrainable):
36 | """Abstract generator module of a scholar module"""
37 |
38 |
39 | class Solver(BatchTrainable):
40 | """Abstract solver module of a scholar module"""
41 | def __init__(self):
42 | super().__init__()
43 | self.optimizer = None
44 | self.criterion = None
45 |
46 | @abc.abstractmethod
47 | def forward(self, x):
48 | raise NotImplementedError
49 |
50 | def solve(self, x):
51 | scores = self(x)
52 | _, predictions = torch.max(scores, 1)
53 | return predictions
54 |
55 | def train_a_batch(self, x, y, x_=None, y_=None, importance_of_new_task=.5):
56 | assert x_ is None or x.size() == x_.size()
57 | assert y_ is None or y.size() == y_.size()
58 |
59 | # clear gradients.
60 | batch_size = x.size(0)
61 | self.optimizer.zero_grad()
62 |
63 | # run the model on the real data.
64 | real_scores = self.forward(x)
65 | real_loss = self.criterion(real_scores, y)
66 | _, real_predicted = real_scores.max(1)
67 | real_prec = (y == real_predicted).sum().data[0] / batch_size
68 |
69 | # run the model on the replayed data.
70 | if x_ is not None and y_ is not None:
71 | replay_scores = self.forward(x_)
72 | replay_loss = self.criterion(replay_scores, y_)
73 | _, replay_predicted = replay_scores.max(1)
74 | replay_prec = (y_ == replay_predicted).sum().data[0] / batch_size
75 |
76 | # calculate joint loss of real data and replayed data.
77 | loss = (
78 | importance_of_new_task * real_loss +
79 | (1-importance_of_new_task) * replay_loss
80 | )
81 | precision = (real_prec + replay_prec) / 2
82 | else:
83 | loss = real_loss
84 | precision = real_prec
85 |
86 | loss.backward()
87 | self.optimizer.step()
88 | return {'loss': loss.data[0], 'precision': precision}
89 |
90 | def set_optimizer(self, optimizer):
91 | self.optimizer = optimizer
92 |
93 | def set_criterion(self, criterion):
94 | self.criterion = criterion
95 |
96 |
97 | class Scholar(GenerativeMixin, nn.Module):
98 | """Scholar for Deep Generative Replay"""
99 | def __init__(self, label, generator, solver):
100 | super().__init__()
101 | self.label = label
102 | self.generator = generator
103 | self.solver = solver
104 |
105 | def train_with_replay(
106 | self, dataset, scholar=None, previous_datasets=None,
107 | importance_of_new_task=.5, batch_size=32,
108 | generator_iterations=2000,
109 | generator_training_callbacks=None,
110 | solver_iterations=1000,
111 | solver_training_callbacks=None,
112 | collate_fn=None):
113 | # scholar and previous datasets cannot be given at the same time.
114 | mutex_condition_infringed = all([
115 | scholar is not None,
116 | bool(previous_datasets)
117 | ])
118 | assert not mutex_condition_infringed, (
119 | 'scholar and previous datasets cannot be given at the same time'
120 | )
121 |
122 | # train the generator of the scholar.
123 | self._train_batch_trainable_with_replay(
124 | self.generator, dataset, scholar,
125 | previous_datasets=previous_datasets,
126 | importance_of_new_task=importance_of_new_task,
127 | batch_size=batch_size,
128 | iterations=generator_iterations,
129 | training_callbacks=generator_training_callbacks,
130 | collate_fn=collate_fn,
131 | )
132 |
133 | # train the solver of the scholar.
134 | self._train_batch_trainable_with_replay(
135 | self.solver, dataset, scholar,
136 | previous_datasets=previous_datasets,
137 | importance_of_new_task=importance_of_new_task,
138 | batch_size=batch_size,
139 | iterations=solver_iterations,
140 | training_callbacks=solver_training_callbacks,
141 | collate_fn=collate_fn,
142 | )
143 |
144 | @property
145 | def name(self):
146 | return self.label
147 |
148 | def sample(self, size):
149 | x = self.generator.sample(size)
150 | y = self.solver.solve(x)
151 | return x.data, y.data
152 |
153 | def _train_batch_trainable_with_replay(
154 | self, trainable, dataset, scholar=None, previous_datasets=None,
155 | importance_of_new_task=.5, batch_size=32, iterations=1000,
156 | training_callbacks=None, collate_fn=None):
157 | # do not train the model when given non-positive iterations.
158 | if iterations <= 0:
159 | return
160 |
161 | # create data loaders.
162 | data_loader = iter(utils.get_data_loader(
163 | dataset, batch_size, cuda=self._is_on_cuda(),
164 | collate_fn=collate_fn,
165 | ))
166 | data_loader_previous = iter(utils.get_data_loader(
167 | ConcatDataset(previous_datasets), batch_size,
168 | cuda=self._is_on_cuda(), collate_fn=collate_fn,
169 | )) if previous_datasets else None
170 |
171 | # define a tqdm progress bar.
172 | progress = tqdm(range(1, iterations+1))
173 |
174 | for batch_index in progress:
175 | # decide from where to sample the training data.
176 | from_scholar = scholar is not None
177 | from_previous_datasets = bool(previous_datasets)
178 | cuda = self._is_on_cuda()
179 |
180 | # sample the real training data.
181 | x, y = next(data_loader)
182 | x = Variable(x).cuda() if cuda else Variable(x)
183 | y = Variable(y).cuda() if cuda else Variable(y)
184 |
185 | # sample the replayed training data.
186 | if from_previous_datasets:
187 | x_, y_ = next(data_loader_previous)
188 | elif from_scholar:
189 | x_, y_ = scholar.sample(batch_size)
190 | else:
191 | x_ = y_ = None
192 |
193 | if x_ is not None and y_ is not None:
194 | x_ = Variable(x_).cuda() if cuda else Variable(x_)
195 | y_ = Variable(y_).cuda() if cuda else Variable(y_)
196 |
197 | # train the model with a batch.
198 | result = trainable.train_a_batch(
199 | x, y, x_=x_, y_=y_,
200 | importance_of_new_task=importance_of_new_task
201 | )
202 |
203 | # fire the callbacks on each iteration.
204 | for callback in (training_callbacks or []):
205 | callback(trainable, progress, batch_index, result)
206 |
207 | def _is_on_cuda(self):
208 | return next(self.parameters()).is_cuda
209 |
--------------------------------------------------------------------------------
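As an illustration of the interface above, here is a minimal concrete solver (a sketch; `LinearSolver` is not part of the repository). A subclass only has to implement `forward`; the optimizer and criterion are injected before training, just as `train.py` does for the CNN solver.

```python
# Sketch: a tiny concrete dgr.Solver for flattened 32x32 grayscale images.
from torch import nn, optim
import dgr


class LinearSolver(dgr.Solver):
    def __init__(self, input_size, classes):
        super().__init__()
        self.fc = nn.Linear(input_size, classes)

    def forward(self, x):
        # flatten the image batch and produce class scores.
        return self.fc(x.view(x.size(0), -1))


solver = LinearSolver(input_size=32*32, classes=10)
solver.set_optimizer(optim.SGD(solver.parameters(), lr=0.1))
solver.set_criterion(nn.CrossEntropyLoss())
```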
/gan.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn import functional as F
3 |
4 |
5 | class Critic(nn.Module):
6 | def __init__(self, image_size, image_channel_size, channel_size):
7 | # configurations
8 | super().__init__()
9 | self.image_size = image_size
10 | self.image_channel_size = image_channel_size
11 | self.channel_size = channel_size
12 |
13 | # layers
14 | self.conv1 = nn.Conv2d(
15 | image_channel_size, channel_size,
16 | kernel_size=4, stride=2, padding=1
17 | )
18 | self.conv2 = nn.Conv2d(
19 | channel_size, channel_size*2,
20 | kernel_size=4, stride=2, padding=1
21 | )
22 | self.conv3 = nn.Conv2d(
23 | channel_size*2, channel_size*4,
24 | kernel_size=4, stride=2, padding=1
25 | )
26 | self.conv4 = nn.Conv2d(
27 | channel_size*4, channel_size*8,
28 | kernel_size=4, stride=1, padding=1,
29 | )
30 |         self.fc = nn.Linear((image_size//8 - 1)**2 * channel_size*8, 1)
31 |
32 | def forward(self, x):
33 | x = F.leaky_relu(self.conv1(x))
34 | x = F.leaky_relu(self.conv2(x))
35 | x = F.leaky_relu(self.conv3(x))
36 | x = F.leaky_relu(self.conv4(x))
37 |         x = x.view(x.size(0), -1)
38 | return self.fc(x)
39 |
40 |
41 | class Generator(nn.Module):
42 | def __init__(self, z_size, image_size, image_channel_size, channel_size):
43 | # configurations
44 | super().__init__()
45 | self.z_size = z_size
46 | self.image_size = image_size
47 | self.image_channel_size = image_channel_size
48 | self.channel_size = channel_size
49 |
50 | # layers
51 | self.fc = nn.Linear(z_size, (image_size//8)**2 * channel_size*8)
52 | self.bn0 = nn.BatchNorm2d(channel_size*8)
53 | self.bn1 = nn.BatchNorm2d(channel_size*4)
54 | self.deconv1 = nn.ConvTranspose2d(
55 | channel_size*8, channel_size*4,
56 | kernel_size=4, stride=2, padding=1
57 | )
58 | self.bn2 = nn.BatchNorm2d(channel_size*2)
59 | self.deconv2 = nn.ConvTranspose2d(
60 | channel_size*4, channel_size*2,
61 | kernel_size=4, stride=2, padding=1,
62 | )
63 | self.bn3 = nn.BatchNorm2d(channel_size)
64 | self.deconv3 = nn.ConvTranspose2d(
65 | channel_size*2, channel_size,
66 | kernel_size=4, stride=2, padding=1
67 | )
68 | self.deconv4 = nn.ConvTranspose2d(
69 | channel_size, image_channel_size,
70 | kernel_size=3, stride=1, padding=1
71 | )
72 |
73 | def forward(self, z):
74 | g = F.relu(self.bn0(self.fc(z).view(
75 | z.size(0),
76 | self.channel_size*8,
77 | self.image_size//8,
78 | self.image_size//8,
79 | )))
80 | g = F.relu(self.bn1(self.deconv1(g)))
81 | g = F.relu(self.bn2(self.deconv2(g)))
82 | g = F.relu(self.bn3(self.deconv3(g)))
83 | g = self.deconv4(g)
84 | return F.sigmoid(g)
85 |
--------------------------------------------------------------------------------
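A quick shape check for the two modules above (a sketch; it assumes 3-channel 32x32 images and a channel size of 64, matching the SVHN / colorized-MNIST experiments).

```python
# Sketch: run the Generator and Critic on a random batch and inspect shapes.
import torch
from torch.autograd import Variable
import gan

generator = gan.Generator(z_size=100, image_size=32,
                          image_channel_size=3, channel_size=64)
critic = gan.Critic(image_size=32, image_channel_size=3, channel_size=64)

z = Variable(torch.randn(8, 100))
images = generator(z)      # 8 x 3 x 32 x 32, values squashed into (0, 1)
scores = critic(images)    # one critic score per generated image
print(images.size(), scores.size())
```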
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os.path
4 | import numpy as np
5 | import torch
6 | import utils
7 | from data import get_dataset, DATASET_CONFIGS
8 | from train import train
9 | from dgr import Scholar
10 | from models import WGAN, CNN
11 |
12 |
13 | parser = argparse.ArgumentParser(
14 | 'PyTorch implementation of Deep Generative Replay'
15 | )
16 |
17 | parser.add_argument(
18 | '--experiment', type=str,
19 | choices=['permutated-mnist', 'svhn-mnist', 'mnist-svhn'],
20 | default='permutated-mnist'
21 | )
22 | parser.add_argument('--mnist-permutation-number', type=int, default=5)
23 | parser.add_argument('--mnist-permutation-seed', type=int, default=0)
24 | parser.add_argument(
25 | '--replay-mode', type=str, default='generative-replay',
26 | choices=['exact-replay', 'generative-replay', 'none'],
27 | )
28 |
29 | parser.add_argument('--generator-lambda', type=float, default=10.)
30 | parser.add_argument('--generator-z-size', type=int, default=100)
31 | parser.add_argument('--generator-c-channel-size', type=int, default=64)
32 | parser.add_argument('--generator-g-channel-size', type=int, default=64)
33 | parser.add_argument('--solver-depth', type=int, default=5)
34 | parser.add_argument('--solver-reducing-layers', type=int, default=3)
35 | parser.add_argument('--solver-channel-size', type=int, default=1024)
36 |
37 | parser.add_argument('--generator-c-updates-per-g-update', type=int, default=5)
38 | parser.add_argument('--generator-iterations', type=int, default=3000)
39 | parser.add_argument('--solver-iterations', type=int, default=1000)
40 | parser.add_argument('--importance-of-new-task', type=float, default=.3)
41 | parser.add_argument('--lr', type=float, default=1e-04)
42 | parser.add_argument('--beta1', type=float, default=0.5)
43 | parser.add_argument('--beta2', type=float, default=0.9)
44 | parser.add_argument('--weight-decay', type=float, default=1e-05)
45 | parser.add_argument('--batch-size', type=int, default=32)
46 | parser.add_argument('--test-size', type=int, default=1024)
47 | parser.add_argument('--sample-size', type=int, default=36)
48 |
49 | parser.add_argument('--sample-log', action='store_true')
50 | parser.add_argument('--sample-log-interval', type=int, default=300)
51 | parser.add_argument('--image-log-interval', type=int, default=100)
52 | parser.add_argument('--eval-log-interval', type=int, default=50)
53 | parser.add_argument('--loss-log-interval', type=int, default=30)
54 | parser.add_argument('--checkpoint-dir', type=str, default='./checkpoints')
55 | parser.add_argument('--sample-dir', type=str, default='./samples')
56 | parser.add_argument('--no-gpus', action='store_false', dest='cuda')
57 |
58 | main_command = parser.add_mutually_exclusive_group(required=True)
59 | main_command.add_argument('--train', action='store_true')
60 | main_command.add_argument('--test', action='store_false', dest='train')
61 |
62 |
63 | if __name__ == '__main__':
64 | args = parser.parse_args()
65 |
66 | # decide whether to use cuda or not.
67 | cuda = torch.cuda.is_available() and args.cuda
68 | experiment = args.experiment
69 | capacity = args.batch_size * max(
70 | args.generator_iterations,
71 | args.solver_iterations
72 | )
73 |
74 | if experiment == 'permutated-mnist':
75 | # generate permutations for the mnist classification tasks.
76 | np.random.seed(args.mnist_permutation_seed)
77 | permutations = [
78 | np.random.permutation(DATASET_CONFIGS['mnist']['size']**2) for
79 | _ in range(args.mnist_permutation_number)
80 | ]
81 |
82 | # prepare the datasets.
83 | train_datasets = [
84 | get_dataset('mnist', permutation=p, capacity=capacity)
85 | for p in permutations
86 | ]
87 | test_datasets = [
88 | get_dataset('mnist', train=False, permutation=p, capacity=capacity)
89 | for p in permutations
90 | ]
91 |
92 | # decide what configuration to use.
93 | dataset_config = DATASET_CONFIGS['mnist']
94 |
95 | elif experiment in ('svhn-mnist', 'mnist-svhn'):
96 | mnist_color_train = get_dataset(
97 | 'mnist-color', train=True, capacity=capacity
98 | )
99 | mnist_color_test = get_dataset(
100 | 'mnist-color', train=False, capacity=capacity
101 | )
102 | svhn_train = get_dataset('svhn', train=True, capacity=capacity)
103 | svhn_test = get_dataset('svhn', train=False, capacity=capacity)
104 |
105 | # prepare the datasets.
106 | train_datasets = (
107 | [mnist_color_train, svhn_train] if experiment == 'mnist-svhn' else
108 | [svhn_train, mnist_color_train]
109 | )
110 | test_datasets = (
111 | [mnist_color_test, svhn_test] if experiment == 'mnist-svhn' else
112 | [svhn_test, mnist_color_test]
113 | )
114 |
115 | # decide what configuration to use.
116 | dataset_config = DATASET_CONFIGS['mnist-color']
117 | else:
118 |         raise RuntimeError('Unknown experiment: {}'.format(experiment))
119 |
120 | # define the models.
121 | cnn = CNN(
122 | image_size=dataset_config['size'],
123 | image_channel_size=dataset_config['channels'],
124 | classes=dataset_config['classes'],
125 | depth=args.solver_depth,
126 | channel_size=args.solver_channel_size,
127 | reducing_layers=args.solver_reducing_layers,
128 | )
129 | wgan = WGAN(
130 | z_size=args.generator_z_size,
131 | image_size=dataset_config['size'],
132 | image_channel_size=dataset_config['channels'],
133 | c_channel_size=args.generator_c_channel_size,
134 | g_channel_size=args.generator_g_channel_size,
135 | )
136 | label = '{experiment}-{replay_mode}-r{importance_of_new_task}'.format(
137 | experiment=experiment,
138 | replay_mode=args.replay_mode,
139 | importance_of_new_task=(
140 | 1 if args.replay_mode == 'none' else
141 | args.importance_of_new_task
142 | ),
143 | )
144 | scholar = Scholar(label, generator=wgan, solver=cnn)
145 |
146 | # initialize the model.
147 |     utils.gaussian_initialize(scholar, std=.02)
148 |
149 | # use cuda if needed
150 | if cuda:
151 | scholar.cuda()
152 |
153 | # determine whether we need to train the generator or not.
154 | train_generator = (
155 | args.replay_mode == 'generative-replay' or
156 | args.sample_log
157 | )
158 |
159 | # run the experiment.
160 | if args.train:
161 | train(
162 | scholar, train_datasets, test_datasets,
163 | replay_mode=args.replay_mode,
164 | generator_lambda=args.generator_lambda,
165 | generator_iterations=(
166 | args.generator_iterations if train_generator else 0
167 | ),
168 | generator_c_updates_per_g_update=(
169 | args.generator_c_updates_per_g_update
170 | ),
171 | solver_iterations=args.solver_iterations,
172 | importance_of_new_task=args.importance_of_new_task,
173 | batch_size=args.batch_size,
174 | test_size=args.test_size,
175 | sample_size=args.sample_size,
176 | lr=args.lr, weight_decay=args.weight_decay,
177 | beta1=args.beta1, beta2=args.beta2,
178 | loss_log_interval=args.loss_log_interval,
179 | eval_log_interval=args.eval_log_interval,
180 | image_log_interval=args.image_log_interval,
181 | sample_log_interval=args.sample_log_interval,
182 | sample_log=args.sample_log,
183 | sample_dir=args.sample_dir,
184 | checkpoint_dir=args.checkpoint_dir,
185 | collate_fn=utils.label_squeezing_collate_fn,
186 | cuda=cuda
187 | )
188 | else:
189 | path = os.path.join(args.sample_dir, '{}-sample'.format(scholar.name))
190 | utils.load_checkpoint(scholar, args.checkpoint_dir)
191 | utils.test_model(scholar.generator, args.sample_size, path)
192 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import torch
3 | from torch import nn, autograd
4 | from torch.autograd import Variable
5 | import gan
6 | import dgr
7 | import utils
8 | from const import EPSILON
9 |
10 |
11 | class WGAN(dgr.Generator):
12 | def __init__(self, z_size,
13 | image_size, image_channel_size,
14 | c_channel_size, g_channel_size):
15 | # configurations
16 | super().__init__()
17 | self.z_size = z_size
18 | self.image_size = image_size
19 | self.image_channel_size = image_channel_size
20 | self.c_channel_size = c_channel_size
21 | self.g_channel_size = g_channel_size
22 |
23 | # components
24 | self.critic = gan.Critic(
25 | image_size=self.image_size,
26 | image_channel_size=self.image_channel_size,
27 | channel_size=self.c_channel_size,
28 | )
29 | self.generator = gan.Generator(
30 | z_size=self.z_size,
31 | image_size=self.image_size,
32 | image_channel_size=self.image_channel_size,
33 | channel_size=self.g_channel_size,
34 | )
35 |
36 | # training related components that should be set before training.
37 | self.generator_optimizer = None
38 | self.critic_optimizer = None
39 | self.critic_updates_per_generator_update = None
40 | self.lamda = None
41 |
42 | def train_a_batch(self, x, y, x_=None, y_=None, importance_of_new_task=.5):
43 | assert x_ is None or x.size() == x_.size()
44 | assert y_ is None or y.size() == y_.size()
45 |
46 | # run the critic and backpropagate the errors.
47 | for _ in range(self.critic_updates_per_generator_update):
48 | self.critic_optimizer.zero_grad()
49 | z = self._noise(x.size(0))
50 |
51 | # run the critic on the real data.
52 | c_loss_real, g_real = self._c_loss(x, z, return_g=True)
53 | c_loss_real_gp = (
54 | c_loss_real + self._gradient_penalty(x, g_real, self.lamda)
55 | )
56 |
57 | # run the critic on the replayed data.
58 | if x_ is not None and y_ is not None:
59 | c_loss_replay, g_replay = self._c_loss(x_, z, return_g=True)
60 | c_loss_replay_gp = (c_loss_replay + self._gradient_penalty(
61 | x_, g_replay, self.lamda
62 | ))
63 | c_loss = (
64 | importance_of_new_task * c_loss_real +
65 | (1-importance_of_new_task) * c_loss_replay
66 | )
67 | c_loss_gp = (
68 | importance_of_new_task * c_loss_real_gp +
69 | (1-importance_of_new_task) * c_loss_replay_gp
70 | )
71 | else:
72 | c_loss = c_loss_real
73 | c_loss_gp = c_loss_real_gp
74 |
75 | c_loss_gp.backward()
76 | self.critic_optimizer.step()
77 |
78 | # run the generator and backpropagate the errors.
79 | self.generator_optimizer.zero_grad()
80 | z = self._noise(x.size(0))
81 | g_loss = self._g_loss(z)
82 | g_loss.backward()
83 | self.generator_optimizer.step()
84 |
85 | return {'c_loss': c_loss.data[0], 'g_loss': g_loss.data[0]}
86 |
87 | def sample(self, size):
88 | return self.generator(self._noise(size))
89 |
90 | def set_generator_optimizer(self, optimizer):
91 | self.generator_optimizer = optimizer
92 |
93 | def set_critic_optimizer(self, optimizer):
94 | self.critic_optimizer = optimizer
95 |
96 | def set_critic_updates_per_generator_update(self, k):
97 | self.critic_updates_per_generator_update = k
98 |
99 | def set_lambda(self, l):
100 | self.lamda = l
101 |
102 | def _noise(self, size):
103 | z = Variable(torch.randn(size, self.z_size)) * .1
104 | return z.cuda() if self._is_on_cuda() else z
105 |
106 | def _c_loss(self, x, z, return_g=False):
107 | g = self.generator(z)
108 | c_x = self.critic(x).mean()
109 | c_g = self.critic(g).mean()
110 | l = -(c_x-c_g)
111 | return (l, g) if return_g else l
112 |
113 | def _g_loss(self, z, return_g=False):
114 | g = self.generator(z)
115 | l = -self.critic(g).mean()
116 | return (l, g) if return_g else l
117 |
118 | def _gradient_penalty(self, x, g, lamda):
119 | assert x.size() == g.size()
120 | a = torch.rand(x.size(0), 1)
121 | a = a.cuda() if self._is_on_cuda() else a
122 | a = a\
123 | .expand(x.size(0), x.nelement()//x.size(0))\
124 | .contiguous()\
125 | .view(
126 | x.size(0),
127 | self.image_channel_size,
128 | self.image_size,
129 | self.image_size
130 | )
131 | interpolated = Variable(a*x.data + (1-a)*g.data, requires_grad=True)
132 | c = self.critic(interpolated)
133 | gradients = autograd.grad(
134 | c, interpolated, grad_outputs=(
135 | torch.ones(c.size()).cuda() if self._is_on_cuda() else
136 | torch.ones(c.size())
137 | ),
138 | create_graph=True,
139 | retain_graph=True,
140 | )[0]
141 | return lamda * ((1-(gradients+EPSILON).norm(2, dim=1))**2).mean()
142 |
143 | def _is_on_cuda(self):
144 | return next(self.parameters()).is_cuda
145 |
146 |
147 | class CNN(dgr.Solver):
148 | def __init__(self,
149 | image_size,
150 | image_channel_size, classes,
151 | depth, channel_size, reducing_layers=3):
152 | # configurations
153 | super().__init__()
154 | self.image_size = image_size
155 | self.image_channel_size = image_channel_size
156 | self.classes = classes
157 | self.depth = depth
158 | self.channel_size = channel_size
159 | self.reducing_layers = reducing_layers
160 |
161 | # layers
162 | self.layers = nn.ModuleList([nn.Conv2d(
163 | self.image_channel_size, self.channel_size//(2**(depth-2)),
164 | 3, 1, 1
165 | )])
166 |
167 | for i in range(self.depth-2):
168 | previous_conv = [
169 | l for l in self.layers if
170 | isinstance(l, nn.Conv2d)
171 | ][-1]
172 | self.layers.append(nn.Conv2d(
173 | previous_conv.out_channels,
174 | previous_conv.out_channels * 2,
175 | 3, 1 if i >= reducing_layers else 2, 1
176 | ))
177 | self.layers.append(nn.BatchNorm2d(previous_conv.out_channels * 2))
178 | self.layers.append(nn.ReLU())
179 |
180 | self.layers.append(utils.LambdaModule(lambda x: x.view(x.size(0), -1)))
181 | self.layers.append(nn.Linear(
182 | (image_size//(2**reducing_layers))**2 * channel_size,
183 | self.classes
184 | ))
185 |
186 | def forward(self, x):
187 | return reduce(lambda x, l: l(x), self.layers, x)
188 |
--------------------------------------------------------------------------------
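Before `WGAN.train_a_batch` can be called, the optimizers, the gradient-penalty weight, and the critic-update ratio have to be injected, as `train.py` does. A minimal wiring sketch follows; the hyperparameters mirror the defaults in `main.py`.

```python
# Sketch: wire up the WGAN's training components before calling train_a_batch.
from torch import optim
from models import WGAN

wgan = WGAN(z_size=100, image_size=32, image_channel_size=3,
            c_channel_size=64, g_channel_size=64)
wgan.set_lambda(10.)                               # gradient penalty weight
wgan.set_critic_updates_per_generator_update(5)
wgan.set_generator_optimizer(optim.Adam(
    wgan.generator.parameters(),
    lr=1e-4, betas=(.5, .9), weight_decay=1e-5,
))
wgan.set_critic_optimizer(optim.Adam(
    wgan.critic.parameters(),
    lr=1e-4, betas=(.5, .9), weight_decay=1e-5,
))
```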
/requirements.txt:
--------------------------------------------------------------------------------
1 | # pytorch
2 | http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl
3 | torchvision==0.1.9
4 | torchtext==0.1.1
5 | visdom
6 |
7 | # data
8 | scipy
9 | scikit-learn
10 | numpy
11 | pillow
12 |
13 | # utils (others)
14 | colorama
15 | tqdm
16 | lmdb
17 | requests
18 | fake-useragent
19 |
--------------------------------------------------------------------------------
/run_full_experiments:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # ================
4 | # Permutated MNIST
5 | # ================
6 |
7 | PERM_MNIST_GENERATOR_ITERATIONS=3000
8 | PERM_MNIST_SOLVER_ITERATIONS=1000
9 | PERM_MNIST_LR=0.0001
10 | PERM_MNIST_IMPORTANCE_OF_NEW_TASK=0.3
11 |
12 | ./main.py --train \
13 | --experiment=permutated-mnist \
14 | --replay-mode=exact-replay \
15 | --solver-iterations=$PERM_MNIST_SOLVER_ITERATIONS \
16 | --lr=$PERM_MNIST_LR \
17 | --importance-of-new-task=$PERM_MNIST_IMPORTANCE_OF_NEW_TASK
18 |
19 | ./main.py --train \
20 | --experiment=permutated-mnist \
21 | --replay-mode=generative-replay \
22 | --generator-iterations=$PERM_MNIST_GENERATOR_ITERATIONS \
23 | --solver-iterations=$PERM_MNIST_SOLVER_ITERATIONS \
24 | --lr=$PERM_MNIST_LR \
25 | --importance-of-new-task=$PERM_MNIST_IMPORTANCE_OF_NEW_TASK
26 |
27 | ./main.py --train \
28 | --experiment=permutated-mnist \
29 | --replay-mode=none \
30 | --solver-iterations=$PERM_MNIST_SOLVER_ITERATIONS \
31 | --lr=$PERM_MNIST_LR
32 |
33 | # ==========
34 | # MNIST-SVHN
35 | # ==========
36 |
37 | MNIST_SVHN_GENERATOR_ITERATIONS=20000
38 | MNIST_SVHN_SOLVER_ITERATIONS=4000
39 | MNIST_SVHN_LR=0.00003
40 | MNIST_SVHN_IMPORTANCE_OF_NEW_TASK=0.4
41 |
42 | ./main.py --train \
43 | --experiment=mnist-svhn \
44 | --replay-mode=exact-replay \
45 | --solver-iterations=$MNIST_SVHN_SOLVER_ITERATIONS \
46 | --importance-of-new-task=$MNIST_SVHN_IMPORTANCE_OF_NEW_TASK \
47 | --lr=$MNIST_SVHN_LR
48 |
49 | ./main.py --train \
50 | --experiment=mnist-svhn \
51 | --replay-mode=generative-replay \
52 | --generator-iterations=$MNIST_SVHN_GENERATOR_ITERATIONS \
53 | --solver-iterations=$MNIST_SVHN_SOLVER_ITERATIONS \
54 | --importance-of-new-task=$MNIST_SVHN_IMPORTANCE_OF_NEW_TASK \
55 | --lr=$MNIST_SVHN_LR \
56 | --sample-log
57 |
58 | ./main.py --train \
59 | --experiment=mnist-svhn \
60 | --replay-mode=none \
61 | --generator-iterations=$MNIST_SVHN_GENERATOR_ITERATIONS \
62 | --solver-iterations=$MNIST_SVHN_SOLVER_ITERATIONS \
63 | --lr=$MNIST_SVHN_LR \
64 | --sample-log
65 |
66 | # ==========
67 | # SVHN-MNIST
68 | # ==========
69 |
70 | SVHN_MNIST_GENERATOR_ITERATIONS=20000
71 | SVHN_MNIST_SOLVER_ITERATIONS=4000
72 | SVHN_MNIST_LR=0.00003
73 | SVHN_MNIST_IMPORTANCE_OF_NEW_TASK=0.4
74 |
75 | ./main.py --train \
76 | --experiment=svhn-mnist \
77 | --replay-mode=exact-replay \
78 | --solver-iterations=$SVHN_MNIST_SOLVER_ITERATIONS \
79 | --importance-of-new-task=$SVHN_MNIST_IMPORTANCE_OF_NEW_TASK \
80 | --lr=$SVHN_MNIST_LR
81 |
82 | ./main.py --train \
83 | --experiment=svhn-mnist \
84 | --replay-mode=generative-replay \
85 | --generator-iterations=$SVHN_MNIST_GENERATOR_ITERATIONS \
86 | --solver-iterations=$SVHN_MNIST_SOLVER_ITERATIONS \
87 | --importance-of-new-task=$SVHN_MNIST_IMPORTANCE_OF_NEW_TASK \
88 | --lr=$SVHN_MNIST_LR \
89 | --sample-log
90 |
91 | ./main.py --train \
92 | --experiment=svhn-mnist \
93 | --replay-mode=none \
94 | --generator-iterations=$SVHN_MNIST_GENERATOR_ITERATIONS \
95 | --solver-iterations=$SVHN_MNIST_SOLVER_ITERATIONS \
96 | --lr=$SVHN_MNIST_LR \
97 | --sample-log
98 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import copy
3 | from torch import optim
4 | from torch import nn
5 | import utils
6 | import visual
7 |
8 |
9 | def train(scholar, train_datasets, test_datasets, replay_mode,
10 | generator_lambda=10.,
11 | generator_c_updates_per_g_update=5,
12 | generator_iterations=2000,
13 | solver_iterations=1000,
14 | importance_of_new_task=.5,
15 | batch_size=32,
16 | test_size=1024,
17 | sample_size=36,
18 | lr=1e-03, weight_decay=1e-05,
19 | beta1=.5, beta2=.9,
20 | loss_log_interval=30,
21 | eval_log_interval=50,
22 | image_log_interval=100,
23 | sample_log_interval=300,
24 | sample_log=False,
25 | sample_dir='./samples',
26 | checkpoint_dir='./checkpoints',
27 | collate_fn=None,
28 | cuda=False):
29 |     # define the solver criterion and the optimizers for the scholar model.
30 | solver_criterion = nn.CrossEntropyLoss()
31 | solver_optimizer = optim.Adam(
32 | scholar.solver.parameters(),
33 | lr=lr, weight_decay=weight_decay, betas=(beta1, beta2),
34 | )
35 | generator_g_optimizer = optim.Adam(
36 | scholar.generator.generator.parameters(),
37 | lr=lr, weight_decay=weight_decay, betas=(beta1, beta2),
38 | )
39 | generator_c_optimizer = optim.Adam(
40 | scholar.generator.critic.parameters(),
41 | lr=lr, weight_decay=weight_decay, betas=(beta1, beta2),
42 | )
43 |
44 | # set the criterion, optimizers, and training configurations for the
45 | # scholar model.
46 | scholar.solver.set_criterion(solver_criterion)
47 | scholar.solver.set_optimizer(solver_optimizer)
48 | scholar.generator.set_lambda(generator_lambda)
49 | scholar.generator.set_generator_optimizer(generator_g_optimizer)
50 | scholar.generator.set_critic_optimizer(generator_c_optimizer)
51 | scholar.generator.set_critic_updates_per_generator_update(
52 | generator_c_updates_per_g_update
53 | )
54 | scholar.train()
55 |
56 | # define the previous scholar who will generate samples of previous tasks.
57 | previous_scholar = None
58 | previous_datasets = None
59 |
60 | for task, train_dataset in enumerate(train_datasets, 1):
61 | # define callbacks for visualizing the training process.
62 | generator_training_callbacks = [_generator_training_callback(
63 | loss_log_interval=loss_log_interval,
64 | image_log_interval=image_log_interval,
65 | sample_log_interval=sample_log_interval,
66 | sample_log=sample_log,
67 | sample_dir=sample_dir,
68 | sample_size=sample_size,
69 | current_task=task,
70 | total_tasks=len(train_datasets),
71 | total_iterations=generator_iterations,
72 | batch_size=batch_size,
73 | replay_mode=replay_mode,
74 | env=scholar.name,
75 | )]
76 | solver_training_callbacks = [_solver_training_callback(
77 | loss_log_interval=loss_log_interval,
78 | eval_log_interval=eval_log_interval,
79 | current_task=task,
80 | total_tasks=len(train_datasets),
81 | total_iterations=solver_iterations,
82 | batch_size=batch_size,
83 | test_size=test_size,
84 | test_datasets=test_datasets,
85 | replay_mode=replay_mode,
86 | cuda=cuda,
87 | collate_fn=collate_fn,
88 | env=scholar.name,
89 | )]
90 |
91 | # train the scholar with generative replay.
92 | scholar.train_with_replay(
93 | train_dataset,
94 | scholar=previous_scholar,
95 | previous_datasets=previous_datasets,
96 | importance_of_new_task=importance_of_new_task,
97 | batch_size=batch_size,
98 | generator_iterations=generator_iterations,
99 | generator_training_callbacks=generator_training_callbacks,
100 | solver_iterations=solver_iterations,
101 | solver_training_callbacks=solver_training_callbacks,
102 | collate_fn=collate_fn,
103 | )
104 |
105 | previous_scholar = (
106 | copy.deepcopy(scholar) if replay_mode == 'generative-replay' else
107 | None
108 | )
109 | previous_datasets = (
110 | train_datasets[:task] if replay_mode == 'exact-replay' else
111 | None
112 | )
113 |
114 | # save the model after the experiment.
115 | print()
116 | utils.save_checkpoint(scholar, checkpoint_dir)
117 | print()
118 | print()
119 |
120 |
121 | def _generator_training_callback(
122 | loss_log_interval,
123 | image_log_interval,
124 | sample_log_interval,
125 | sample_log,
126 | sample_dir,
127 | current_task,
128 | total_tasks,
129 | total_iterations,
130 | batch_size,
131 | sample_size,
132 | replay_mode,
133 | env):
134 |
135 | def cb(generator, progress, batch_index, result):
136 | iteration = (current_task-1)*total_iterations + batch_index
137 | progress.set_description((
138 | ' '
139 | 'task: {task}/{tasks} | '
140 | 'progress: [{trained}/{total}] ({percentage:.0f}%) | '
141 | 'loss => '
142 | 'g: {g_loss:.4} / '
143 | 'w: {w_dist:.4}'
144 | ).format(
145 | task=current_task,
146 | tasks=total_tasks,
147 | trained=batch_size * batch_index,
148 | total=batch_size * total_iterations,
149 | percentage=(100.*batch_index/total_iterations),
150 | g_loss=result['g_loss'],
151 | w_dist=-result['c_loss'],
152 | ))
153 |
154 | # log the losses of the generator.
155 | if iteration % loss_log_interval == 0:
156 | visual.visualize_scalar(
157 | result['g_loss'], 'generator g loss', iteration, env=env
158 | )
159 | visual.visualize_scalar(
160 | -result['c_loss'], 'generator w distance', iteration, env=env
161 | )
162 |
163 | # log the generated images of the generator.
164 | if iteration % image_log_interval == 0:
165 | visual.visualize_images(
166 | generator.sample(sample_size).data,
167 | 'generated samples ({replay_mode})'
168 | .format(replay_mode=replay_mode), env=env,
169 | )
170 |
171 | # log the sample images of the generator
172 | if iteration % sample_log_interval == 0 and sample_log:
173 | utils.test_model(generator, sample_size, os.path.join(
174 | sample_dir,
175 | env + '-sample-logs',
176 | str(iteration)
177 | ), verbose=False)
178 |
179 | return cb
180 |
181 |
182 | def _solver_training_callback(
183 | loss_log_interval,
184 | eval_log_interval,
185 | current_task,
186 | total_tasks,
187 | total_iterations,
188 | batch_size,
189 | test_size,
190 | test_datasets,
191 | cuda,
192 | replay_mode,
193 | collate_fn,
194 | env):
195 |
196 | def cb(solver, progress, batch_index, result):
197 | iteration = (current_task-1)*total_iterations + batch_index
198 | progress.set_description((
199 | ' '
200 | 'task: {task}/{tasks} | '
201 | 'progress: [{trained}/{total}] ({percentage:.0f}%) | '
202 | 'loss: {loss:.4} | '
203 | 'prec: {prec:.4}'
204 | ).format(
205 | task=current_task,
206 | tasks=total_tasks,
207 | trained=batch_size * batch_index,
208 | total=batch_size * total_iterations,
209 | percentage=(100.*batch_index/total_iterations),
210 | loss=result['loss'],
211 | prec=result['precision'],
212 | ))
213 |
214 | # log the loss of the solver.
215 | if iteration % loss_log_interval == 0:
216 | visual.visualize_scalar(
217 | result['loss'], 'solver loss', iteration, env=env
218 | )
219 |
220 | # evaluate the solver on multiple tasks.
221 | if iteration % eval_log_interval == 0:
222 | names = ['task {}'.format(i+1) for i in range(len(test_datasets))]
223 | precs = [
224 | utils.validate(
225 | solver, test_datasets[i], test_size=test_size,
226 | cuda=cuda, verbose=False, collate_fn=collate_fn,
227 | ) if i+1 <= current_task else 0 for i in
228 | range(len(test_datasets))
229 | ]
230 | title = 'precision ({replay_mode})'.format(replay_mode=replay_mode)
231 | visual.visualize_scalars(
232 | precs, names, title,
233 | iteration, env=env
234 | )
235 |
236 | return cb
237 |
--------------------------------------------------------------------------------
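The callbacks defined above share a single signature: `dgr.Scholar._train_batch_trainable_with_replay` invokes each one as `callback(trainable, progress, batch_index, result)`, where `result` is the dict returned by `train_a_batch`. A minimal custom solver callback, as a sketch:

```python
# Sketch: a custom solver-training callback compatible with train.py / dgr.py.
def print_solver_progress(solver, progress, batch_index, result):
    # `result` comes from Solver.train_a_batch and holds 'loss' and 'precision'.
    if batch_index % 100 == 0:
        progress.write('batch {}: loss {:.4f} | precision {:.4f}'.format(
            batch_index, result['loss'], result['precision'],
        ))
```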
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torchvision
4 | import torch
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data.dataloader import default_collate
9 |
10 |
11 | def label_squeezing_collate_fn(batch):
12 | x, y = default_collate(batch)
13 | return x, y.long().squeeze()
14 |
15 |
16 | def get_data_loader(dataset, batch_size, cuda=False, collate_fn=None):
17 | return DataLoader(
18 | dataset, batch_size=batch_size, shuffle=True,
19 | collate_fn=(collate_fn or default_collate),
20 | **({'num_workers': 0, 'pin_memory': True} if cuda else {})
21 | )
22 |
23 |
24 | def save_checkpoint(model, model_dir):
25 | path = os.path.join(model_dir, model.name)
26 |
27 | # save the checkpoint.
28 | if not os.path.exists(model_dir):
29 | os.makedirs(model_dir)
30 |
31 | torch.save({'state': model.state_dict()}, path)
32 |
33 | # notify that we successfully saved the checkpoint.
34 | print('=> saved the model {name} to {path}'.format(
35 | name=model.name, path=path
36 | ))
37 |
38 |
39 | def load_checkpoint(model, model_dir):
40 | path = os.path.join(model_dir, model.name)
41 |
42 | # load the checkpoint.
43 | checkpoint = torch.load(path)
44 | print('=> loaded checkpoint of {name} from {path}'.format(
45 | name=model.name, path=path
46 | ))
47 |
48 |     # load the model parameters from the checkpoint.
49 | model.load_state_dict(checkpoint['state'])
50 |
51 |
52 | def test_model(model, sample_size, path, verbose=True):
53 | os.makedirs(os.path.dirname(path), exist_ok=True)
54 | torchvision.utils.save_image(
55 | model.sample(sample_size).data,
56 | path + '.jpg',
57 | nrow=6,
58 | )
59 | if verbose:
60 | print('=> generated sample images at "{}".'.format(path))
61 |
62 |
63 | def validate(model, dataset, test_size=1024,
64 | cuda=False, verbose=True, collate_fn=None):
65 | data_loader = get_data_loader(
66 | dataset, 128, cuda=cuda,
67 | collate_fn=(collate_fn or default_collate),
68 | )
69 | total_tested = 0
70 | total_correct = 0
71 | for data, labels in data_loader:
72 | # break on test size.
73 | if total_tested >= test_size:
74 | break
75 | # test the model.
76 | data = Variable(data).cuda() if cuda else Variable(data)
77 | labels = Variable(labels).cuda() if cuda else Variable(labels)
78 | scores = model(data)
79 | _, predicted = torch.max(scores, 1)
80 |
81 | # update statistics.
82 | total_correct += (predicted == labels).sum().data[0]
83 | total_tested += len(data)
84 |
85 | precision = total_correct / total_tested
86 | if verbose:
87 | print('=> precision: {:.3f}'.format(precision))
88 | return precision
89 |
90 |
91 | def xavier_initialize(model):
92 | modules = [m for n, m in model.named_modules() if 'conv' in n or 'fc' in n]
93 | parameters = [p for m in modules for p in m.parameters()]
94 |
95 | for p in parameters:
96 | if p.dim() >= 2:
97 | nn.init.xavier_normal(p)
98 | else:
99 | nn.init.constant(p, 0)
100 |
101 |
102 | def gaussian_initialize(model, std=.01):
103 | modules = [m for n, m in model.named_modules() if 'conv' in n or 'fc' in n]
104 | parameters = [p for m in modules for p in m.parameters()]
105 |
106 | for p in parameters:
107 | if p.dim() >= 2:
108 | nn.init.normal(p, std=std)
109 | else:
110 | nn.init.constant(p, 0)
111 |
112 |
113 | class LambdaModule(nn.Module):
114 | def __init__(self, f):
115 | super().__init__()
116 | self.f = f
117 |
118 | def forward(self, x):
119 | return self.f(x)
120 |
--------------------------------------------------------------------------------
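A short sketch of the data-loading helpers above combined with `data.get_dataset`; `label_squeezing_collate_fn` is what `main.py` passes so that targets collate into a flat `LongTensor` regardless of how the dataset stores its labels.

```python
# Sketch: an SVHN loader using the helpers above (downloads SVHN on first use).
from data import get_dataset
from utils import get_data_loader, label_squeezing_collate_fn

loader = get_data_loader(
    get_dataset('svhn', train=True), batch_size=32,
    collate_fn=label_squeezing_collate_fn,
)
x, y = next(iter(loader))   # x: 32x3x32x32 float tensor, y: 32 integer labels
```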
/visual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.cuda import FloatTensor as CUDATensor
3 | from visdom import Visdom
4 |
5 | _WINDOW_CASH = {}
6 |
7 |
8 | def _vis(env='main'):
9 | return Visdom(env=env)
10 |
11 |
12 | def visualize_image(tensor, name, label=None, env='main', w=250, h=250,
13 | update_window_without_label=False):
14 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor
15 | title = name + ('-{}'.format(label) if label is not None else '')
16 |
17 | _WINDOW_CASH[title] = _vis(env).image(
18 | tensor.numpy(), win=_WINDOW_CASH.get(title),
19 | opts=dict(title=title, width=w, height=h)
20 | )
21 |
22 | # This is useful when you want to maintain the most recent images.
23 | if update_window_without_label:
24 | _WINDOW_CASH[name] = _vis(env).image(
25 | tensor.numpy(), win=_WINDOW_CASH.get(name),
26 | opts=dict(title=name, width=w, height=h)
27 | )
28 |
29 |
30 | def visualize_images(tensor, name, label=None, env='main', w=400, h=400,
31 | update_window_without_label=False):
32 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor
33 | title = name + ('-{}'.format(label) if label is not None else '')
34 |
35 | _WINDOW_CASH[title] = _vis(env).images(
36 | tensor.numpy(), win=_WINDOW_CASH.get(title), nrow=6,
37 | opts=dict(title=title, width=w, height=h)
38 | )
39 |
40 | # This is useful when you want to maintain the most recent images.
41 | if update_window_without_label:
42 | _WINDOW_CASH[name] = _vis(env).images(
43 | tensor.numpy(), win=_WINDOW_CASH.get(name), nrow=6,
44 | opts=dict(title=name, width=w, height=h)
45 | )
46 |
47 |
48 | def visualize_kernel(kernel, name, label=None, env='main', w=250, h=250,
49 | update_window_without_label=False, compress_tensor=False):
51 |     # Do not visualize kernels that do not exist.
51 | if kernel is None:
52 | return
53 |
54 | assert len(kernel.size()) in (2, 4)
55 | title = name + ('-{}'.format(label) if label is not None else '')
56 | kernel = kernel.cpu() if isinstance(kernel, CUDATensor) else kernel
57 | kernel_norm = kernel if len(kernel.size()) == 2 else (
58 | (kernel**2).mean(-1).mean(-1) if compress_tensor else
59 | kernel.view(
60 | kernel.size()[0] * kernel.size()[2],
61 | kernel.size()[1] * kernel.size()[3],
62 | )
63 | )
64 | kernel_norm = kernel_norm.abs()
65 |
66 | visualized = (
67 | (kernel_norm - kernel_norm.min()) /
68 | (kernel_norm.max() - kernel_norm.min())
69 | ).numpy()
70 |
71 | _WINDOW_CASH[title] = _vis(env).image(
72 | visualized, win=_WINDOW_CASH.get(title),
73 | opts=dict(title=title, width=w, height=h)
74 | )
75 |
76 | # This is useful when you want to maintain the most recent images.
77 | if update_window_without_label:
78 | _WINDOW_CASH[name] = _vis(env).image(
79 | visualized, win=_WINDOW_CASH.get(name),
80 | opts=dict(title=name, width=w, height=h)
81 | )
82 |
83 |
84 | def visualize_scalar(scalar, name, iteration, env='main'):
85 | visualize_scalars(
86 | [scalar] if isinstance(scalar, float) or len(scalar) == 1 else scalar,
87 | [name], name, iteration, env=env
88 | )
89 |
90 |
91 | def visualize_scalars(scalars, names, title, iteration, env='main'):
92 | assert len(scalars) == len(names)
93 | # Convert scalar tensors to numpy arrays.
94 | scalars, names = list(scalars), list(names)
95 | scalars = [s.cpu() if isinstance(s, CUDATensor) else s for s in scalars]
96 | scalars = [s.numpy() if hasattr(s, 'numpy') else np.array([s]) for s in
97 | scalars]
98 | multi = len(scalars) > 1
99 | num = len(scalars)
100 |
101 | options = dict(
102 | fillarea=True,
103 | legend=names,
104 | width=400,
105 | height=400,
106 | xlabel='Iterations',
107 | ylabel=title,
108 | title=title,
109 | marginleft=30,
110 | marginright=30,
111 | marginbottom=80,
112 | margintop=30,
113 | )
114 |
115 | X = (
116 | np.column_stack(np.array([iteration] * num)) if multi else
117 | np.array([iteration] * num)
118 | )
119 | Y = np.column_stack(scalars) if multi else scalars[0]
120 |
121 | if title in _WINDOW_CASH:
122 | _vis(env).updateTrace(X=X, Y=Y, win=_WINDOW_CASH[title], opts=options)
123 | else:
124 | _WINDOW_CASH[title] = _vis(env).line(X=X, Y=Y, opts=options)
125 |
--------------------------------------------------------------------------------
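A minimal sketch of the logging helpers above; it assumes a visdom server is already running locally, e.g. via `python -m visdom.server`.

```python
# Sketch: push one scalar point to a visdom environment named 'demo'.
import visual

visual.visualize_scalar(0.5, 'demo loss', 0, env='demo')
```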