├── .gitignore
├── LICENSE
├── README.md
├── arts
├── graphic-image.jpg
├── precision-consolidated.png
└── precision-plain.png
├── data.py
├── main.py
├── model.py
├── requirements.txt
├── train.py
├── utils.py
└── visual.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 | datasets/
103 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ha Junsoo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-ewc
2 | Unofficial PyTorch implementation of DeepMind's paper [Overcoming Catastrophic Forgetting, PNAS 2017](https://arxiv.org/abs/1612.00796).
3 |
4 | 
5 |
6 | ## Results
7 |
8 | Continual Learning **without EWC** (*left*) and **with EWC** (*right*).
9 |
10 |
11 |
12 |
13 | ## Installation
14 | ```
15 | $ git clone https://github.com/kuc2477/pytorch-ewc && cd pytorch-ewc
16 | $ pip install -r requirements.txt
17 | ```
18 |
19 |
20 | ## CLI
21 | Implementation CLI is provided by `main.py`
22 |
23 |
24 | #### Usage
25 | ```
26 | $ ./main.py --help
27 | $ usage: EWC PyTorch Implementation [-h] [--hidden-size HIDDEN_SIZE]
28 | [--hidden-layer-num HIDDEN_LAYER_NUM]
29 | [--hidden-dropout-prob HIDDEN_DROPOUT_PROB]
30 | [--input-dropout-prob INPUT_DROPOUT_PROB]
31 | [--task-number TASK_NUMBER]
32 | [--epochs-per-task EPOCHS_PER_TASK]
33 | [--lamda LAMDA] [--lr LR]
34 | [--weight-decay WEIGHT_DECAY]
35 | [--batch-size BATCH_SIZE]
36 | [--test-size TEST_SIZE]
37 | [--fisher-estimation-sample-size FISHER_ESTIMATION_SAMPLE_SIZE]
38 | [--random-seed RANDOM_SEED] [--no-gpus]
39 | [--eval-log-interval EVAL_LOG_INTERVAL]
40 | [--loss-log-interval LOSS_LOG_INTERVAL]
41 | [--consolidate]
42 |
43 | optional arguments:
44 | -h, --help show this help message and exit
45 | --hidden-size HIDDEN_SIZE
46 | --hidden-layer-num HIDDEN_LAYER_NUM
47 | --hidden-dropout-prob HIDDEN_DROPOUT_PROB
48 | --input-dropout-prob INPUT_DROPOUT_PROB
49 | --task-number TASK_NUMBER
50 | --epochs-per-task EPOCHS_PER_TASK
51 | --lamda LAMDA
52 | --lr LR
53 | --weight-decay WEIGHT_DECAY
54 | --batch-size BATCH_SIZE
55 | --test-size TEST_SIZE
56 | --fisher-estimation-sample-size FISHER_ESTIMATION_SAMPLE_SIZE
57 | --random-seed RANDOM_SEED
58 | --no-gpus
59 | --eval-log-interval EVAL_LOG_INTERVAL
60 | --loss-log-interval LOSS_LOG_INTERVAL
61 | --consolidate
62 |
63 | ```
64 |
65 |
66 | #### Train
67 | ```
68 | $ python -m visdom.server &
69 | $ ./main.py # Train the network without consolidation.
70 | $ ./main.py --consolidate # Train the network with consolidation.
71 | ```
72 |
73 |
74 | ## Update Logs
75 | - 2019.06.29
76 | - **Fixed a critical bug within `model.estimate_fisher()`**: Squared gradients of the log-likelihood w.r.t. each layer were being mean-reduced over all dimensions. It now correctly estimates the Fisher matrix by averaging only over the batch dimension.
77 | - 2019.03.22
78 | - **Fixed a critical bug within `model.estimate_fisher()`**: The Fisher matrix was being estimated with the squared expectation of the gradient of the log-likelihood. It now estimates the Fisher matrix with the expectation of the squared gradient of the log-likelihood.
79 | - Changed the default optimizer from Adam to SGD
80 | - Migrated the project to PyTorch 1.0.1 and visdom 0.1.8.8
81 |
82 | ## Reference
83 | - [Overcoming Catastrophic Forgetting, PNAS 2017](https://arxiv.org/abs/1612.00796)
84 |
85 | ## Author
86 | Ha Junsoo / [@kuc2477](https://github.com/kuc2477) / MIT License
87 |
--------------------------------------------------------------------------------
/arts/graphic-image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-ewc/4afaa6666d6b4f1a91a110caf69e7b77f049dc08/arts/graphic-image.jpg
--------------------------------------------------------------------------------
/arts/precision-consolidated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-ewc/4afaa6666d6b4f1a91a110caf69e7b77f049dc08/arts/precision-consolidated.png
--------------------------------------------------------------------------------
/arts/precision-plain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kuc2477/pytorch-ewc/4afaa6666d6b4f1a91a110caf69e7b77f049dc08/arts/precision-plain.png
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms
2 |
3 |
def _permutate_image_pixels(image, permutation):
    """Shuffle the pixels of one image tensor with a fixed permutation.

    Args:
        image: Tensor whose ``size()`` unpacks into (channels, height, width).
        permutation: Index sequence of length ``height * width`` giving the
            new pixel order, or ``None`` to leave the image untouched.

    Returns:
        The input itself when ``permutation`` is ``None``; otherwise a new
        tensor of the same (channels, height, width) shape in which every
        channel has had its pixels reordered by ``permutation``.
    """
    if permutation is None:
        return image

    c, h, w = image.size()
    # Flatten only the spatial dimensions so that the same pixel shuffle is
    # applied to each channel independently (channels are never mixed).
    flat = image.reshape(c, h * w)
    shuffled = flat[:, permutation]
    # Restore the original layout; the previous implementation computed this
    # view but threw the result away and returned a (h*w, c) tensor instead.
    return shuffled.reshape(c, h, w)
13 |
14 |
def get_dataset(name, train=True, download=True, permutation=None):
    """Build a registered dataset with its transforms plus a pixel shuffle.

    Args:
        name: Key into AVAILABLE_DATASETS / AVAILABLE_TRANSFORMS.
        train: Forwarded to the dataset class as ``train``.
        download: Forwarded to the dataset class as ``download``.
        permutation: Pixel permutation applied as the last transform step;
            ``None`` leaves images as produced by the registered transforms.

    Returns:
        An instance of the registered dataset class rooted at
        ``./datasets/<name>``.
    """
    factory = AVAILABLE_DATASETS[name]

    def permute(image):
        # Bound to this dataset's permutation; runs after the base pipeline.
        return _permutate_image_pixels(image, permutation)

    steps = list(AVAILABLE_TRANSFORMS[name])
    steps.append(transforms.Lambda(permute))

    root = './datasets/{name}'.format(name=name)
    return factory(
        root,
        train=train,
        download=download,
        transform=transforms.Compose(steps),
    )
26 |
27 |
# Dataset classes that get_dataset() can instantiate, keyed by name.
AVAILABLE_DATASETS = dict(mnist=datasets.MNIST)

# Per-dataset transform steps applied before the pixel permutation.
AVAILABLE_TRANSFORMS = dict(
    mnist=[
        transforms.ToTensor(),
        transforms.ToPILImage(),
        # Pads 2 pixels per side (presumably 28x28 -> 32x32, matching the
        # 'size' entry in DATASET_CONFIGS -- confirm against the dataset).
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ],
)

# Static facts about each dataset that callers use to size the model.
DATASET_CONFIGS = dict(
    mnist=dict(size=32, channels=1, classes=10),
)
45 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from argparse import ArgumentParser
3 | import numpy as np
4 | import torch
5 | from data import get_dataset, DATASET_CONFIGS
6 | from train import train
7 | from model import MLP
8 | import utils
9 |
10 |
def _add_typed_options(target, specs):
    """Register ``(flag, type, default)`` triples on ``target`` in order."""
    for flag, kind, default in specs:
        target.add_argument(flag, type=kind, default=default)


# Command-line interface of the experiment script.
parser = ArgumentParser('EWC PyTorch Implementation')

# model architecture.
_add_typed_options(parser, (
    ('--hidden-size', int, 400),
    ('--hidden-layer-num', int, 2),
    ('--hidden-dropout-prob', float, .5),
    ('--input-dropout-prob', float, .2),
))

# continual-learning schedule and optimisation.
_add_typed_options(parser, (
    ('--task-number', int, 8),
    ('--epochs-per-task', int, 3),
    ('--lamda', float, 40),
    ('--lr', float, 1e-1),
    ('--weight-decay', float, 0),
    ('--batch-size', int, 128),
    ('--test-size', int, 1024),
    ('--fisher-estimation-sample-size', int, 1024),
    ('--random-seed', int, 0),
))

# `--no-gpus` switches the `cuda` destination off (it defaults to True).
parser.add_argument('--no-gpus', action='store_false', dest='cuda')

# logging cadence, in iterations.
_add_typed_options(parser, (
    ('--eval-log-interval', int, 250),
    ('--loss-log-interval', int, 250),
))

parser.add_argument('--consolidate', action='store_true')
30 |
31 |
if __name__ == '__main__':
    # Entry point: parse the CLI, build one permuted-MNIST dataset pair per
    # task, construct the MLP and hand everything to train().
    args = parser.parse_args()

    # decide whether to use cuda or not.
    cuda = torch.cuda.is_available() and args.cuda

    # seed every RNG involved: numpy drives the task permutations, torch
    # drives weight initialisation, data shuffling and dropout. Seeding only
    # numpy left runs non-reproducible despite `--random-seed`.
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    mnist_config = DATASET_CONFIGS['mnist']
    pixel_count = mnist_config['size'] ** 2

    # generate one pixel permutation per task.
    permutations = [
        np.random.permutation(pixel_count) for _ in range(args.task_number)
    ]

    # prepare mnist datasets (train and test share each task's permutation).
    train_datasets = [
        get_dataset('mnist', permutation=p) for p in permutations
    ]
    test_datasets = [
        get_dataset('mnist', train=False, permutation=p) for p in permutations
    ]

    # prepare the model.
    mlp = MLP(
        pixel_count,
        mnist_config['classes'],
        hidden_size=args.hidden_size,
        hidden_layer_num=args.hidden_layer_num,
        hidden_dropout_prob=args.hidden_dropout_prob,
        input_dropout_prob=args.input_dropout_prob,
        lamda=args.lamda,
    )

    # initialize the parameters.
    utils.xavier_initialize(mlp)

    # prepare the cuda if needed.
    if cuda:
        mlp.cuda()

    # run the experiment.
    train(
        mlp, train_datasets, test_datasets,
        epochs_per_task=args.epochs_per_task,
        batch_size=args.batch_size,
        test_size=args.test_size,
        consolidate=args.consolidate,
        fisher_estimation_sample_size=args.fisher_estimation_sample_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        eval_log_interval=args.eval_log_interval,
        loss_log_interval=args.loss_log_interval,
        cuda=cuda,
    )
85 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch import autograd
6 | from torch.autograd import Variable
7 | import utils
8 |
9 |
class MLP(nn.Module):
    """Fully connected classifier with elastic weight consolidation helpers.

    The network is ``Linear -> ReLU -> Dropout`` for the input stage, the same
    triple repeated ``hidden_layer_num`` times for the hidden stage, and a
    final ``Linear`` producing ``output_size`` scores.

    After a task has been learned, ``estimate_fisher`` + ``consolidate`` store
    the current parameters and their Fisher diagonals as buffers, and
    ``ewc_loss`` penalises moving away from them.
    """

    def __init__(self, input_size, output_size,
                 hidden_size=400,
                 hidden_layer_num=2,
                 hidden_dropout_prob=.5,
                 input_dropout_prob=.2,
                 lamda=40):
        """Store the configuration and build the layer stack.

        Args:
            input_size: Number of input features.
            output_size: Number of output scores.
            hidden_size: Width of every hidden layer.
            hidden_layer_num: Number of hidden Linear/ReLU/Dropout triples.
            hidden_dropout_prob: Dropout probability inside hidden triples.
            input_dropout_prob: Dropout probability of the input triple.
            lamda: Weight of the EWC penalty (see ``ewc_loss``).
        """
        # Configurations.
        super().__init__()
        self.input_size = input_size
        self.input_dropout_prob = input_dropout_prob
        self.hidden_size = hidden_size
        self.hidden_layer_num = hidden_layer_num
        self.hidden_dropout_prob = hidden_dropout_prob
        self.output_size = output_size
        self.lamda = lamda

        # Layers.
        # NOTE(review): the "input" dropout sits after the first Linear+ReLU,
        # not on the raw input; kept as in the original architecture.
        layers = [
            nn.Linear(self.input_size, self.hidden_size), nn.ReLU(),
            nn.Dropout(self.input_dropout_prob),
        ]
        # Build fresh modules for every hidden layer. Repeating a tuple of
        # modules would reuse the very same Linear instance, silently tying
        # the weights of all hidden layers together.
        for _ in range(self.hidden_layer_num):
            layers.extend([
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                nn.Dropout(self.hidden_dropout_prob),
            ])
        layers.append(nn.Linear(self.hidden_size, self.output_size))
        self.layers = nn.ModuleList(layers)

    @property
    def name(self):
        """Identifier string derived from the hyperparameters."""
        return (
            'MLP'
            '-lambda{lamda}'
            '-in{input_size}-out{output_size}'
            '-h{hidden_size}x{hidden_layer_num}'
            '-dropout_in{input_dropout_prob}_hidden{hidden_dropout_prob}'
        ).format(
            lamda=self.lamda,
            input_size=self.input_size,
            output_size=self.output_size,
            hidden_size=self.hidden_size,
            hidden_layer_num=self.hidden_layer_num,
            input_dropout_prob=self.input_dropout_prob,
            hidden_dropout_prob=self.hidden_dropout_prob,
        )

    def forward(self, x):
        """Apply every layer in order and return the output scores."""
        for layer in self.layers:
            x = layer(x)
        return x

    def estimate_fisher(self, dataset, sample_size, batch_size=32):
        """Estimate the diagonal of the Fisher information of the parameters.

        For each sampled example the gradient of the log-likelihood of its
        label is taken w.r.t. every parameter; the squared gradients are
        averaged over the examples.

        Args:
            dataset: Dataset yielding ``(x, y)`` pairs.
            sample_size: Approximate number of examples to use; batches are
                drawn until ``sample_size // batch_size`` have been seen (at
                least one batch is always processed).
            batch_size: Batch size of the sampling data loader.

        Returns:
            Dict mapping parameter names (``.`` replaced by ``__``) to
            detached tensors shaped like the corresponding parameter.

        Raises:
            ValueError: If the dataset yields no examples.
        """
        data_loader = utils.get_data_loader(dataset, batch_size)
        on_cuda = self._is_on_cuda()
        params = list(self.parameters())
        # Running sum of squared gradients: avoids materialising one full
        # gradient copy per sample before reducing.
        squared_sums = [torch.zeros_like(p) for p in params]
        batches_needed = sample_size // batch_size
        batches_seen = 0
        samples_seen = 0

        for x, y in data_loader:
            # Use the real batch length: the last batch may be smaller than
            # `batch_size`.
            n = x.size(0)
            x = x.view(n, -1)
            if on_cuda:
                x, y = x.cuda(), y.cuda()
            loglikelihoods = F.log_softmax(self(x), dim=1)[range(n), y]
            for i, l in enumerate(loglikelihoods.unbind(), 1):
                # The batch graph is shared by all its samples, so keep it
                # alive until the last one has been differentiated.
                grads = autograd.grad(l, params, retain_graph=(i < n))
                for total, g in zip(squared_sums, grads):
                    total.add_(g.detach() ** 2)
            samples_seen += n
            batches_seen += 1
            if batches_seen >= batches_needed:
                break

        if samples_seen == 0:
            raise ValueError('cannot estimate fisher from an empty dataset')

        param_names = [
            n.replace('.', '__') for n, p in self.named_parameters()
        ]
        return {
            n: (total / samples_seen).detach()
            for n, total in zip(param_names, squared_sums)
        }

    def consolidate(self, fisher):
        """Snapshot current parameters and ``fisher`` as module buffers.

        Args:
            fisher: Dict as returned by ``estimate_fisher``.
        """
        for n, p in self.named_parameters():
            n = n.replace('.', '__')
            self.register_buffer('{}_mean'.format(n), p.data.clone())
            self.register_buffer('{}_fisher'
                                 .format(n), fisher[n].data.clone())

    def ewc_loss(self, cuda=False):
        """Return the EWC penalty against the consolidated parameters.

        Args:
            cuda: Whether the zero placeholder returned before any
                consolidation should live on the GPU.

        Returns:
            ``(lamda / 2) * sum(fisher * (param - mean) ** 2)`` over all
            parameters, or a one-element zero tensor when some parameter has
            no consolidated buffers yet.
        """
        losses = []
        for n, p in self.named_parameters():
            # retrieve the consolidated mean and fisher information.
            n = n.replace('.', '__')
            mean = getattr(self, '{}_mean'.format(n), None)
            fisher = getattr(self, '{}_fisher'.format(n), None)
            if mean is None or fisher is None:
                # ewc loss is 0 if there's no consolidated parameters.
                zero = torch.zeros(1)
                return zero.cuda() if cuda else zero
            # calculate a ewc loss. (assumes the parameter's prior as
            # gaussian distribution with the estimated mean and the
            # estimated cramer-rao lower bound variance, which is
            # equivalent to the inverse of fisher information)
            losses.append((fisher * (p - mean) ** 2).sum())
        return (self.lamda / 2) * sum(losses)

    def _is_on_cuda(self):
        """Tell whether the first parameter lives on a CUDA device."""
        return next(self.parameters()).is_cuda
119 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # pytorch
2 | torch==1.0.1.post2
3 | torchvision==0.2.2.post3
4 | torchtext==0.1.1
5 | visdom==0.1.8.8
6 |
7 | # data
8 | scipy
9 | scikit-learn
10 | numpy
11 | pillow
12 |
13 | # utils (debugging)
14 | pdbpp
15 | ipdb
16 | ipython
17 | jupyter
18 | jupyterthemes
19 | jupyter_contrib_nbextensions
20 | jupyter_nbextensions_configurator
21 |
22 | # utils (others)
23 | colorama
24 | tqdm
25 | lmdb
26 | requests
27 | fake-useragent
28 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from tqdm import tqdm
5 | from visdom import Visdom
6 | import utils
7 | import visual
8 |
9 |
def train(model, train_datasets, test_datasets, epochs_per_task=10,
          batch_size=64, test_size=1024, consolidate=True,
          fisher_estimation_sample_size=1024,
          lr=1e-3, weight_decay=1e-5,
          loss_log_interval=30,
          eval_log_interval=50,
          cuda=False):
    """Train ``model`` on a sequence of tasks, optionally consolidating.

    Tasks are learned one after another with SGD on
    ``cross entropy + model.ewc_loss()``. Precisions and losses are sent to a
    visdom server at the configured intervals.

    Args:
        model: Network exposing ``name``, ``ewc_loss``, ``estimate_fisher``
            and ``consolidate``.
        train_datasets: One training dataset per task, in training order.
        test_datasets: Test datasets aligned with ``train_datasets``.
        epochs_per_task: Number of passes over each task's dataset.
        batch_size: Training batch size.
        test_size: Number of test examples used per evaluation.
        consolidate: Whether to consolidate parameters after each task
            (except the last one).
        fisher_estimation_sample_size: Sample size passed to
            ``model.estimate_fisher``.
        lr: SGD learning rate.
        weight_decay: SGD weight decay.
        loss_log_interval: Iterations between loss plots.
        eval_log_interval: Iterations between precision plots.
        cuda: Whether batches are moved to the GPU.
    """
    # prepare the loss criterion and the optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr,
                          weight_decay=weight_decay)

    # instantiate a visdom client
    vis = Visdom(env=model.name)

    # set the model's mode to training mode.
    model.train()

    task_count = len(train_datasets)
    task_names = ['task {}'.format(i + 1) for i in range(task_count)]
    precision_title = (
        'precision (consolidated)' if consolidate else 'precision'
    )
    loss_title = 'loss (consolidated)' if consolidate else 'loss'

    # Global step counter across all tasks and epochs. Counting the batches
    # actually processed keeps the x-axis strictly increasing; deriving it
    # from `epochs * len(dataset) // batch_size` disagreed with the real
    # (ceil-divided) number of batches and made iterations collide at task
    # boundaries.
    iteration = 0

    for task, train_dataset in enumerate(train_datasets, 1):
        for epoch in range(1, epochs_per_task + 1):
            # prepare the data loaders.
            data_loader = utils.get_data_loader(
                train_dataset, batch_size=batch_size,
                cuda=cuda
            )
            dataset_size = len(data_loader.dataset)
            dataset_batches = len(data_loader)
            data_stream = tqdm(enumerate(data_loader, 1))

            for batch_index, (x, y) in data_stream:
                iteration += 1

                # prepare the data.
                data_size = len(x)
                x = x.view(data_size, -1)
                if cuda:
                    x, y = x.cuda(), y.cuda()

                # run the model and backpropagate the errors.
                optimizer.zero_grad()
                scores = model(x)
                ce_loss = criterion(scores, y)
                ewc_loss = model.ewc_loss(cuda=cuda)
                loss = ce_loss + ewc_loss
                loss.backward()
                optimizer.step()

                # calculate the training precision.
                _, predicted = scores.max(1)
                precision = (predicted == y).sum().float() / len(x)

                data_stream.set_description((
                    '=> '
                    'task: {task}/{tasks} | '
                    'epoch: {epoch}/{epochs} | '
                    'progress: [{trained}/{total}] ({progress:.0f}%) | '
                    'prec: {prec:.4} | '
                    'loss => '
                    'ce: {ce_loss:.4} / '
                    'ewc: {ewc_loss:.4} / '
                    'total: {loss:.4}'
                ).format(
                    task=task,
                    tasks=task_count,
                    epoch=epoch,
                    epochs=epochs_per_task,
                    # the last batch may be partial; never report more
                    # examples than the dataset holds.
                    trained=min(batch_index * batch_size, dataset_size),
                    total=dataset_size,
                    progress=(100. * batch_index / dataset_batches),
                    prec=float(precision),
                    ce_loss=float(ce_loss),
                    ewc_loss=float(ewc_loss),
                    loss=float(loss),
                ))

                # Send test precision to the visdom server.
                if iteration % eval_log_interval == 0:
                    # tasks not reached yet are plotted as 0.
                    precs = [
                        utils.validate(
                            model, test_datasets[i], test_size=test_size,
                            cuda=cuda, verbose=False,
                        ) if i + 1 <= task else 0 for i in
                        range(task_count)
                    ]
                    visual.visualize_scalars(
                        vis, precs, task_names, precision_title,
                        iteration
                    )

                # Send losses to the visdom server.
                if iteration % loss_log_interval == 0:
                    visual.visualize_scalars(
                        vis,
                        [loss, ce_loss, ewc_loss],
                        ['total', 'cross entropy', 'ewc'],
                        loss_title, iteration
                    )

        if consolidate and task < task_count:
            # estimate the fisher information of the parameters and consolidate
            # them in the network.
            print(
                '=> Estimating diagonals of the fisher information matrix...',
                flush=True, end='',
            )
            model.consolidate(model.estimate_fisher(
                train_dataset, fisher_estimation_sample_size
            ))
            print(' Done!')
139 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import shutil
4 | import torch
5 | from torch.autograd import Variable
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data.dataloader import default_collate
8 | from torch.nn import init
9 |
10 |
def get_data_loader(dataset, batch_size, cuda=False, collate_fn=None):
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of examples per batch.
        cuda: When true, the loader uses 2 worker processes and pinned
            memory.
        collate_fn: Custom batch collation; ``default_collate`` when falsy.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': True,
        'collate_fn': collate_fn if collate_fn else default_collate,
    }
    if cuda:
        loader_options['num_workers'] = 2
        loader_options['pin_memory'] = True
    return DataLoader(dataset, **loader_options)
18 |
19 |
def save_checkpoint(model, model_dir, epoch, precision, best=True):
    """Persist the model state together with its epoch and precision.

    The checkpoint is written to ``<model_dir>/<model.name>``; when ``best``
    is true it is additionally copied to ``<model_dir>/<model.name>-best``.

    Args:
        model: Object exposing ``name`` and ``state_dict()``.
        model_dir: Directory to write into; created when missing.
        epoch: Epoch number stored in the checkpoint.
        precision: Precision value stored in the checkpoint.
        best: Whether to also overwrite the "best" copy.
    """
    path = os.path.join(model_dir, model.name)
    path_best = os.path.join(model_dir, '{}-best'.format(model.name))

    # save the checkpoint. `exist_ok` avoids the check-then-create race of a
    # separate `os.path.exists` test.
    os.makedirs(model_dir, exist_ok=True)
    torch.save({
        'state': model.state_dict(),
        'epoch': epoch,
        'precision': precision,
    }, path)

    # override the best model if it's the best.
    if best:
        shutil.copy(path, path_best)
        print('=> updated the best model of {name} at {path}'.format(
            name=model.name, path=path_best
        ))

    # notify that we successfully saved the checkpoint.
    print('=> saved the model {name} to {path}'.format(
        name=model.name, path=path
    ))
44 |
45 |
def load_checkpoint(model, model_dir, best=True):
    """Restore model parameters from a checkpoint written by save_checkpoint.

    Args:
        model: Object exposing ``name`` and ``load_state_dict()``.
        model_dir: Directory containing the checkpoint files.
        best: Load ``<model.name>-best`` when true, ``<model.name>`` otherwise.

    Returns:
        Tuple ``(epoch, precision)`` read from the checkpoint.
    """
    suffix = '-best' if best else ''
    path = os.path.join(model_dir, '{}{}'.format(model.name, suffix))

    # load the checkpoint. Tensors are mapped to the CPU so a checkpoint
    # saved from a GPU run can be restored on a machine without CUDA;
    # `load_state_dict` then copies them onto the model's own device.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    print('=> loaded checkpoint of {name} from {path}'.format(
        name=model.name, path=path
    ))

    # load parameters and return the checkpoint's epoch and precision.
    model.load_state_dict(checkpoint['state'])
    epoch = checkpoint['epoch']
    precision = checkpoint['precision']
    return epoch, precision
61 |
62 |
def validate(model, dataset, test_size=256, batch_size=32,
             cuda=False, verbose=True):
    """Measure classification precision of ``model`` on part of ``dataset``.

    The model is switched to evaluation mode for the measurement and put
    back into its previous mode afterwards.

    Args:
        model: Callable module returning per-class scores.
        dataset: Dataset yielding ``(x, y)`` pairs.
        test_size: Stop once at least this many examples were tested.
        batch_size: Batch size of the evaluation data loader.
        cuda: Whether batches are moved to the GPU.
        verbose: Print the resulting precision when true.

    Returns:
        Fraction of correctly classified examples, or ``0.`` when no example
        was tested.
    """
    mode = model.training
    model.train(mode=False)
    data_loader = get_data_loader(dataset, batch_size, cuda=cuda)
    total_tested = 0
    total_correct = 0
    try:
        # no gradients are needed for evaluation.
        with torch.no_grad():
            for x, y in data_loader:
                # break on test size.
                if total_tested >= test_size:
                    break
                # test the model. Flatten with the real batch length: the
                # final batch can be smaller than `batch_size`.
                x = x.view(len(x), -1)
                if cuda:
                    x, y = x.cuda(), y.cuda()
                scores = model(x)
                _, predicted = scores.max(1)
                # update statistics.
                total_correct += int((predicted == y).sum())
                total_tested += len(x)
    finally:
        # restore the caller's mode even if evaluation failed.
        model.train(mode=mode)
    # guard against an empty dataset or a non-positive test size.
    precision = total_correct / total_tested if total_tested else 0.
    if verbose:
        print('=> precision: {:.3f}'.format(precision))
    return precision
88 |
89 |
def xavier_initialize(model):
    """Apply Xavier-normal initialisation to linear/convolutional weights.

    A submodule is selected when it is a linear or convolution layer, or
    (as before) when its registered name contains ``'conv'`` or ``'linear'``.
    Only parameters with at least two dimensions are re-initialised, so
    biases are left untouched.

    Args:
        model: Module whose submodules are scanned.
    """
    # Select by type: matching on names alone misses layers registered under
    # generated names such as `layers.0` inside a ModuleList.
    layer_types = (
        torch.nn.Linear,
        torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
        torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d,
    )

    for name, module in model.named_modules():
        selected = (
            isinstance(module, layer_types) or
            'conv' in name or 'linear' in name
        )
        if not selected:
            continue
        for p in module.parameters():
            if p.dim() >= 2:
                # in-place variant; `init.xavier_normal` is deprecated.
                init.xavier_normal_(p)
105 |
106 |
def gaussian_intiailize(model, std=.1):
    """Fill every parameter of ``model`` with zero-mean Gaussian noise.

    Args:
        model: Module whose parameters are overwritten in place.
        std: Standard deviation of the normal distribution.
    """
    for p in model.parameters():
        # in-place variant; `init.normal` is deprecated.
        init.normal_(p, std=std)


# Correctly spelled alias; the misspelled name is kept for existing callers.
gaussian_initialize = gaussian_intiailize
110 |
--------------------------------------------------------------------------------
/visual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.cuda import FloatTensor as CUDATensor
3 |
# Visdom window handles keyed by window title, so that later calls with the
# same title update the existing window instead of opening a new one.
_WINDOW_CASH = dict()
5 |
6 |
def visualize_image(vis, tensor, name, label=None, w=250, h=250,
                    update_window_without_label=False):
    """Show a single image tensor in a visdom window.

    Args:
        vis: Visdom client.
        tensor: Image tensor to display.
        name: Base window title.
        label: Optional suffix appended to the title as ``-<label>``.
        w: Window width.
        h: Window height.
        update_window_without_label: Also draw into the window titled just
            ``name``.
    """
    # `.cpu()` is a no-op for host tensors, so call it unconditionally; this
    # also covers CUDA tensors that are not float32. Detach so tensors that
    # are part of an autograd graph can be converted to numpy.
    image = tensor.detach().cpu().numpy()
    title = name + ('-{}'.format(label) if label is not None else '')

    _WINDOW_CASH[title] = vis.image(
        image, win=_WINDOW_CASH.get(title),
        opts=dict(title=title, width=w, height=h)
    )

    # This is useful when you want to maintain the most recent images.
    if update_window_without_label:
        _WINDOW_CASH[name] = vis.image(
            image, win=_WINDOW_CASH.get(name),
            opts=dict(title=name, width=w, height=h)
        )
23 |
24 |
def visualize_images(vis, tensor, name, label=None, w=250, h=250,
                     update_window_without_label=False):
    """Show a batch of image tensors in a visdom window.

    Args:
        vis: Visdom client.
        tensor: Batch of images to display.
        name: Base window title.
        label: Optional suffix appended to the title as ``-<label>``.
        w: Window width.
        h: Window height.
        update_window_without_label: Also draw into the window titled just
            ``name``.
    """
    # `.cpu()` is a no-op for host tensors, so call it unconditionally; this
    # also covers CUDA tensors that are not float32. Detach so tensors that
    # are part of an autograd graph can be converted to numpy.
    images = tensor.detach().cpu().numpy()
    title = name + ('-{}'.format(label) if label is not None else '')

    _WINDOW_CASH[title] = vis.images(
        images, win=_WINDOW_CASH.get(title),
        opts=dict(title=title, width=w, height=h)
    )

    # This is useful when you want to maintain the most recent images.
    if update_window_without_label:
        _WINDOW_CASH[name] = vis.images(
            images, win=_WINDOW_CASH.get(name),
            opts=dict(title=name, width=w, height=h)
        )
41 |
42 |
def visualize_kernel(vis, kernel, name, label=None, w=250, h=250,
                     update_window_without_label=False, compress_tensor=False):
    """Show the magnitude of a weight tensor as a min-max scaled image.

    Args:
        vis: Visdom client.
        kernel: 2-D or 4-D weight tensor, or ``None`` (nothing is drawn).
        name: Base window title.
        label: Optional suffix appended to the title as ``-<label>``.
        w: Window width.
        h: Window height.
        update_window_without_label: Also draw into the window titled just
            ``name``.
        compress_tensor: For 4-D kernels, reduce the last two dimensions to
            their mean square instead of unrolling them into a 2-D grid.
    """
    # Do not visualize kernels that does not exists.
    if kernel is None:
        return

    assert len(kernel.size()) in (2, 4)
    title = name + ('-{}'.format(label) if label is not None else '')
    # `.cpu()` is a no-op on host tensors and, unlike a float-only CUDA type
    # check, handles every CUDA dtype; detach to allow the numpy conversion.
    kernel = kernel.detach().cpu()
    if len(kernel.size()) == 2:
        kernel_norm = kernel
    elif compress_tensor:
        kernel_norm = (kernel ** 2).mean(-1).mean(-1)
    else:
        kernel_norm = kernel.view(
            kernel.size()[0] * kernel.size()[2],
            kernel.size()[1] * kernel.size()[3],
        )
    kernel_norm = kernel_norm.abs()

    lowest = kernel_norm.min()
    spread = kernel_norm.max() - lowest
    scaled = kernel_norm - lowest
    # A constant kernel has zero spread; dividing would yield NaNs, so it is
    # rendered as all zeros instead.
    if float(spread) > 0:
        scaled = scaled / spread
    visualized = scaled.numpy()

    _WINDOW_CASH[title] = vis.image(
        visualized, win=_WINDOW_CASH.get(title),
        opts=dict(title=title, width=w, height=h)
    )

    # This is useful when you want to maintain the most recent images.
    if update_window_without_label:
        _WINDOW_CASH[name] = vis.image(
            visualized, win=_WINDOW_CASH.get(name),
            opts=dict(title=name, width=w, height=h)
        )
77 |
78 |
def visualize_scalar(vis, scalar, name, iteration):
    """Plot a single scalar series named ``name`` at ``iteration``.

    Args:
        vis: Visdom client.
        scalar: Value to plot (number, one-element sequence or tensor).
        name: Used both as the legend entry and as the window title.
        iteration: X coordinate of the new point.
    """
    try:
        is_single = isinstance(scalar, float) or len(scalar) == 1
    except TypeError:
        # ints and 0-dim tensors have no usable length; they are single
        # values and must be wrapped too.
        is_single = True

    visualize_scalars(
        vis,
        [scalar] if is_single else scalar,
        [name], name, iteration
    )
85 |
86 |
def visualize_scalars(vis, scalars, names, title, iteration):
    """Append one point per named series to the visdom line plot ``title``.

    Args:
        vis: Visdom client.
        scalars: Values to plot, one per entry of ``names``.
        names: Legend entries.
        title: Window title; also used as the y-axis label and cache key.
        iteration: X coordinate shared by all new points.
    """
    assert len(scalars) == len(names)
    scalars, names = list(scalars), list(names)

    # Convert every value to a numpy array (tensors via the host).
    converted = []
    for value in scalars:
        if isinstance(value, CUDATensor):
            value = value.cpu()
        if hasattr(value, 'numpy'):
            converted.append(value.detach().numpy())
        else:
            converted.append(np.array([value]))
    count = len(converted)

    options = {
        'fillarea': True,
        'legend': names,
        'width': 400,
        'height': 400,
        'xlabel': 'Iterations',
        'ylabel': title,
        'title': title,
        'marginleft': 30,
        'marginright': 30,
        'marginbottom': 80,
        'margintop': 30,
    }

    # Several series are laid out as one row with a column per series.
    xs = np.array([iteration] * count)
    if count > 1:
        X = np.column_stack(xs)
        Y = np.column_stack(converted)
    else:
        X = xs
        Y = converted[0]

    if title not in _WINDOW_CASH:
        # first point: create the window and remember its handle.
        _WINDOW_CASH[title] = vis.line(X=X, Y=Y, opts=options)
        return

    vis.line(
        X=X, Y=Y, win=_WINDOW_CASH[title], opts=options, update='append'
    )
123 |
--------------------------------------------------------------------------------