├── .gitignore ├── LICENSE ├── README.md ├── arts ├── graphic-image.jpg ├── precision-consolidated.png └── precision-plain.png ├── data.py ├── main.py ├── model.py ├── requirements.txt ├── train.py ├── utils.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | datasets/ 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ha Junsoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-ewc 2 | Unofficial PyTorch implementation of DeepMind's paper [Overcoming Catastrophic Forgetting, PNAS 2017](https://arxiv.org/abs/1612.00796). 3 | 4 | ![graphic-image](./arts/graphic-image.jpg) 5 | 6 | ## Results 7 | 8 | Continual Learning **without EWC** (*left*) and **with EWC** (*right*). 9 | 10 | 11 | 12 | 13 | ## Installation 14 | ``` 15 | $ git clone https://github.com/kuc2477/pytorch-ewc && cd pytorch-ewc 16 | $ pip install -r requirements.txt 17 | ``` 18 | 19 | 20 | ## CLI 21 | Implementation CLI is provided by `main.py` 22 | 23 | 24 | #### Usage 25 | ``` 26 | $ ./main.py --help 27 | $ usage: EWC PyTorch Implementation [-h] [--hidden-size HIDDEN_SIZE] 28 | [--hidden-layer-num HIDDEN_LAYER_NUM] 29 | [--hidden-dropout-prob HIDDEN_DROPOUT_PROB] 30 | [--input-dropout-prob INPUT_DROPOUT_PROB] 31 | [--task-number TASK_NUMBER] 32 | [--epochs-per-task EPOCHS_PER_TASK] 33 | [--lamda LAMDA] [--lr LR] 34 | [--weight-decay WEIGHT_DECAY] 35 | [--batch-size BATCH_SIZE] 36 | [--test-size TEST_SIZE] 37 | [--fisher-estimation-sample-size FISHER_ESTIMATION_SAMPLE_SIZE] 38 | [--random-seed RANDOM_SEED] [--no-gpus] 39 | [--eval-log-interval EVAL_LOG_INTERVAL] 40 | [--loss-log-interval LOSS_LOG_INTERVAL] 41 | [--consolidate] 42 | 43 | optional arguments: 44 | -h, --help show this help message and exit 45 | --hidden-size HIDDEN_SIZE 46 | --hidden-layer-num HIDDEN_LAYER_NUM 47 | --hidden-dropout-prob HIDDEN_DROPOUT_PROB 48 | --input-dropout-prob INPUT_DROPOUT_PROB 49 | --task-number TASK_NUMBER 50 | --epochs-per-task EPOCHS_PER_TASK 51 | --lamda LAMDA 52 | --lr LR 53 | --weight-decay WEIGHT_DECAY 54 | --batch-size BATCH_SIZE 55 | --test-size TEST_SIZE 56 | --fisher-estimation-sample-size FISHER_ESTIMATION_SAMPLE_SIZE 57 | --random-seed RANDOM_SEED 58 | --no-gpus 59 | --eval-log-interval EVAL_LOG_INTERVAL 60 | --loss-log-interval LOSS_LOG_INTERVAL 61 | --consolidate 62 | 63 | ``` 64 | 65 | 66 | #### Train 67 | ``` 68 | $ python -m visdom.server & 69 | $ ./main.py # Train the network without consolidation. 70 | $ ./main.py --consolidate # Train the network with consolidation. 71 | ``` 72 | 73 | 74 | ## Update Logs 75 | - 2019.06.29 76 | - **Fixed a critical bug within `model.estimate_fisher()`**: Squared gradients of log-likelihood w.r.t. each layer were mean-reduced over all the dimensions. Now it correctly estimates the Fisher matrix by averaging only over the batch dimension 77 | - 2019.03.22 78 | - **Fixed a critical bug within `model.estimate_fisher()`**: Fisher matrix were being estimated with squared expectation of gradient of log-likelihoods. Now it estimates the Fisher matrix with the expectation of squared gradient of log-likelihood. 79 | - Changed the default optimizer from Adam to SGD 80 | - Migrated the project to PyTorch 1.0.1 and visdom 0.1.8.8 81 | 82 | ## Reference 83 | - [Overcoming Catastrophic Forgetting, PNAS 2017](https://arxiv.org/abs/1612.00796) 84 | 85 | ## Author 86 | Ha Junsoo / [@kuc2477](https://github.com/kuc2477) / MIT License 87 | -------------------------------------------------------------------------------- /arts/graphic-image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-ewc/4afaa6666d6b4f1a91a110caf69e7b77f049dc08/arts/graphic-image.jpg -------------------------------------------------------------------------------- /arts/precision-consolidated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-ewc/4afaa6666d6b4f1a91a110caf69e7b77f049dc08/arts/precision-consolidated.png -------------------------------------------------------------------------------- /arts/precision-plain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-ewc/4afaa6666d6b4f1a91a110caf69e7b77f049dc08/arts/precision-plain.png -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | 4 | def _permutate_image_pixels(image, permutation): 5 | if permutation is None: 6 | return image 7 | 8 | c, h, w = image.size() 9 | image = image.view(-1, c) 10 | image = image[permutation, :] 11 | image.view(c, h, w) 12 | return image 13 | 14 | 15 | def get_dataset(name, train=True, download=True, permutation=None): 16 | dataset_class = AVAILABLE_DATASETS[name] 17 | dataset_transform = transforms.Compose([ 18 | *AVAILABLE_TRANSFORMS[name], 19 | transforms.Lambda(lambda x: _permutate_image_pixels(x, permutation)), 20 | ]) 21 | 22 | return dataset_class( 23 | './datasets/{name}'.format(name=name), train=train, 24 | download=download, transform=dataset_transform, 25 | ) 26 | 27 | 28 | AVAILABLE_DATASETS = { 29 | 'mnist': datasets.MNIST 30 | } 31 | 32 | AVAILABLE_TRANSFORMS = { 33 | 'mnist': [ 34 | transforms.ToTensor(), 35 | transforms.ToPILImage(), 36 | transforms.Pad(2), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.1307,), (0.3081,)), 39 | ] 40 | } 41 | 42 | DATASET_CONFIGS = { 43 | 'mnist': {'size': 32, 'channels': 1, 'classes': 10} 44 | } 45 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from argparse import ArgumentParser 3 | import numpy as np 4 | import torch 5 | from data import get_dataset, DATASET_CONFIGS 6 | from train import train 7 | from model import MLP 8 | import utils 9 | 10 | 11 | parser = ArgumentParser('EWC PyTorch Implementation') 12 | parser.add_argument('--hidden-size', type=int, default=400) 13 | parser.add_argument('--hidden-layer-num', type=int, default=2) 14 | parser.add_argument('--hidden-dropout-prob', type=float, default=.5) 15 | parser.add_argument('--input-dropout-prob', type=float, default=.2) 16 | 17 | parser.add_argument('--task-number', type=int, default=8) 18 | parser.add_argument('--epochs-per-task', type=int, default=3) 19 | parser.add_argument('--lamda', type=float, default=40) 20 | parser.add_argument('--lr', type=float, default=1e-1) 21 | parser.add_argument('--weight-decay', type=float, default=0) 22 | parser.add_argument('--batch-size', type=int, default=128) 23 | parser.add_argument('--test-size', type=int, default=1024) 24 | parser.add_argument('--fisher-estimation-sample-size', type=int, default=1024) 25 | parser.add_argument('--random-seed', type=int, default=0) 26 | parser.add_argument('--no-gpus', action='store_false', dest='cuda') 27 | parser.add_argument('--eval-log-interval', type=int, default=250) 28 | parser.add_argument('--loss-log-interval', type=int, default=250) 29 | parser.add_argument('--consolidate', action='store_true') 30 | 31 | 32 | if __name__ == '__main__': 33 | args = parser.parse_args() 34 | 35 | # decide whether to use cuda or not. 36 | cuda = torch.cuda.is_available() and args.cuda 37 | 38 | # generate permutations for the tasks. 39 | np.random.seed(args.random_seed) 40 | permutations = [ 41 | np.random.permutation(DATASET_CONFIGS['mnist']['size']**2) for 42 | _ in range(args.task_number) 43 | ] 44 | 45 | # prepare mnist datasets. 46 | train_datasets = [ 47 | get_dataset('mnist', permutation=p) for p in permutations 48 | ] 49 | test_datasets = [ 50 | get_dataset('mnist', train=False, permutation=p) for p in permutations 51 | ] 52 | 53 | # prepare the model. 54 | mlp = MLP( 55 | DATASET_CONFIGS['mnist']['size']**2, 56 | DATASET_CONFIGS['mnist']['classes'], 57 | hidden_size=args.hidden_size, 58 | hidden_layer_num=args.hidden_layer_num, 59 | hidden_dropout_prob=args.hidden_dropout_prob, 60 | input_dropout_prob=args.input_dropout_prob, 61 | lamda=args.lamda, 62 | ) 63 | 64 | # initialize the parameters. 65 | utils.xavier_initialize(mlp) 66 | 67 | # prepare the cuda if needed. 68 | if cuda: 69 | mlp.cuda() 70 | 71 | # run the experiment. 72 | train( 73 | mlp, train_datasets, test_datasets, 74 | epochs_per_task=args.epochs_per_task, 75 | batch_size=args.batch_size, 76 | test_size=args.test_size, 77 | consolidate=args.consolidate, 78 | fisher_estimation_sample_size=args.fisher_estimation_sample_size, 79 | lr=args.lr, 80 | weight_decay=args.weight_decay, 81 | eval_log_interval=args.eval_log_interval, 82 | loss_log_interval=args.loss_log_interval, 83 | cuda=cuda 84 | ) 85 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch import autograd 6 | from torch.autograd import Variable 7 | import utils 8 | 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, input_size, output_size, 12 | hidden_size=400, 13 | hidden_layer_num=2, 14 | hidden_dropout_prob=.5, 15 | input_dropout_prob=.2, 16 | lamda=40): 17 | # Configurations. 18 | super().__init__() 19 | self.input_size = input_size 20 | self.input_dropout_prob = input_dropout_prob 21 | self.hidden_size = hidden_size 22 | self.hidden_layer_num = hidden_layer_num 23 | self.hidden_dropout_prob = hidden_dropout_prob 24 | self.output_size = output_size 25 | self.lamda = lamda 26 | 27 | # Layers. 28 | self.layers = nn.ModuleList([ 29 | # input 30 | nn.Linear(self.input_size, self.hidden_size), nn.ReLU(), 31 | nn.Dropout(self.input_dropout_prob), 32 | # hidden 33 | *((nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), 34 | nn.Dropout(self.hidden_dropout_prob)) * self.hidden_layer_num), 35 | # output 36 | nn.Linear(self.hidden_size, self.output_size) 37 | ]) 38 | 39 | @property 40 | def name(self): 41 | return ( 42 | 'MLP' 43 | '-lambda{lamda}' 44 | '-in{input_size}-out{output_size}' 45 | '-h{hidden_size}x{hidden_layer_num}' 46 | '-dropout_in{input_dropout_prob}_hidden{hidden_dropout_prob}' 47 | ).format( 48 | lamda=self.lamda, 49 | input_size=self.input_size, 50 | output_size=self.output_size, 51 | hidden_size=self.hidden_size, 52 | hidden_layer_num=self.hidden_layer_num, 53 | input_dropout_prob=self.input_dropout_prob, 54 | hidden_dropout_prob=self.hidden_dropout_prob, 55 | ) 56 | 57 | def forward(self, x): 58 | return reduce(lambda x, l: l(x), self.layers, x) 59 | 60 | def estimate_fisher(self, dataset, sample_size, batch_size=32): 61 | # sample loglikelihoods from the dataset. 62 | data_loader = utils.get_data_loader(dataset, batch_size) 63 | loglikelihoods = [] 64 | for x, y in data_loader: 65 | x = x.view(batch_size, -1) 66 | x = Variable(x).cuda() if self._is_on_cuda() else Variable(x) 67 | y = Variable(y).cuda() if self._is_on_cuda() else Variable(y) 68 | loglikelihoods.append( 69 | F.log_softmax(self(x), dim=1)[range(batch_size), y.data] 70 | ) 71 | if len(loglikelihoods) >= sample_size // batch_size: 72 | break 73 | # estimate the fisher information of the parameters. 74 | loglikelihoods = torch.cat(loglikelihoods).unbind() 75 | loglikelihood_grads = zip(*[autograd.grad( 76 | l, self.parameters(), 77 | retain_graph=(i < len(loglikelihoods)) 78 | ) for i, l in enumerate(loglikelihoods, 1)]) 79 | loglikelihood_grads = [torch.stack(gs) for gs in loglikelihood_grads] 80 | fisher_diagonals = [(g ** 2).mean(0) for g in loglikelihood_grads] 81 | param_names = [ 82 | n.replace('.', '__') for n, p in self.named_parameters() 83 | ] 84 | return {n: f.detach() for n, f in zip(param_names, fisher_diagonals)} 85 | 86 | def consolidate(self, fisher): 87 | for n, p in self.named_parameters(): 88 | n = n.replace('.', '__') 89 | self.register_buffer('{}_mean'.format(n), p.data.clone()) 90 | self.register_buffer('{}_fisher' 91 | .format(n), fisher[n].data.clone()) 92 | 93 | def ewc_loss(self, cuda=False): 94 | try: 95 | losses = [] 96 | for n, p in self.named_parameters(): 97 | # retrieve the consolidated mean and fisher information. 98 | n = n.replace('.', '__') 99 | mean = getattr(self, '{}_mean'.format(n)) 100 | fisher = getattr(self, '{}_fisher'.format(n)) 101 | # wrap mean and fisher in variables. 102 | mean = Variable(mean) 103 | fisher = Variable(fisher) 104 | # calculate a ewc loss. (assumes the parameter's prior as 105 | # gaussian distribution with the estimated mean and the 106 | # estimated cramer-rao lower bound variance, which is 107 | # equivalent to the inverse of fisher information) 108 | losses.append((fisher * (p-mean)**2).sum()) 109 | return (self.lamda/2)*sum(losses) 110 | except AttributeError: 111 | # ewc loss is 0 if there's no consolidated parameters. 112 | return ( 113 | Variable(torch.zeros(1)).cuda() if cuda else 114 | Variable(torch.zeros(1)) 115 | ) 116 | 117 | def _is_on_cuda(self): 118 | return next(self.parameters()).is_cuda 119 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | torch==1.0.1.post2 3 | torchvision==0.2.2.post3 4 | torchtext==0.1.1 5 | visdom==0.1.8.8 6 | 7 | # data 8 | scipy 9 | scikit-learn 10 | numpy 11 | pillow 12 | 13 | # utils (debugging) 14 | pdbpp 15 | ipdb 16 | ipython 17 | jupyter 18 | jupyterthemes 19 | jupyter_contrib_nbextensions 20 | jupyter_nbextensions_configurator 21 | 22 | # utils (others) 23 | colorama 24 | tqdm 25 | lmdb 26 | requests 27 | fake-useragent 28 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from tqdm import tqdm 5 | from visdom import Visdom 6 | import utils 7 | import visual 8 | 9 | 10 | def train(model, train_datasets, test_datasets, epochs_per_task=10, 11 | batch_size=64, test_size=1024, consolidate=True, 12 | fisher_estimation_sample_size=1024, 13 | lr=1e-3, weight_decay=1e-5, 14 | loss_log_interval=30, 15 | eval_log_interval=50, 16 | cuda=False): 17 | # prepare the loss criteriton and the optimizer. 18 | criteriton = nn.CrossEntropyLoss() 19 | optimizer = optim.SGD(model.parameters(), lr=lr, 20 | weight_decay=weight_decay) 21 | 22 | # instantiate a visdom client 23 | vis = Visdom(env=model.name) 24 | 25 | # set the model's mode to training mode. 26 | model.train() 27 | 28 | for task, train_dataset in enumerate(train_datasets, 1): 29 | for epoch in range(1, epochs_per_task+1): 30 | # prepare the data loaders. 31 | data_loader = utils.get_data_loader( 32 | train_dataset, batch_size=batch_size, 33 | cuda=cuda 34 | ) 35 | data_stream = tqdm(enumerate(data_loader, 1)) 36 | 37 | for batch_index, (x, y) in data_stream: 38 | # where are we? 39 | data_size = len(x) 40 | dataset_size = len(data_loader.dataset) 41 | dataset_batches = len(data_loader) 42 | previous_task_iteration = sum([ 43 | epochs_per_task * len(d) // batch_size for d in 44 | train_datasets[:task-1] 45 | ]) 46 | current_task_iteration = ( 47 | (epoch-1)*dataset_batches + batch_index 48 | ) 49 | iteration = ( 50 | previous_task_iteration + 51 | current_task_iteration 52 | ) 53 | 54 | # prepare the data. 55 | x = x.view(data_size, -1) 56 | x = Variable(x).cuda() if cuda else Variable(x) 57 | y = Variable(y).cuda() if cuda else Variable(y) 58 | 59 | # run the model and backpropagate the errors. 60 | optimizer.zero_grad() 61 | scores = model(x) 62 | ce_loss = criteriton(scores, y) 63 | ewc_loss = model.ewc_loss(cuda=cuda) 64 | loss = ce_loss + ewc_loss 65 | loss.backward() 66 | optimizer.step() 67 | 68 | # calculate the training precision. 69 | _, predicted = scores.max(1) 70 | precision = (predicted == y).sum().float() / len(x) 71 | 72 | data_stream.set_description(( 73 | '=> ' 74 | 'task: {task}/{tasks} | ' 75 | 'epoch: {epoch}/{epochs} | ' 76 | 'progress: [{trained}/{total}] ({progress:.0f}%) | ' 77 | 'prec: {prec:.4} | ' 78 | 'loss => ' 79 | 'ce: {ce_loss:.4} / ' 80 | 'ewc: {ewc_loss:.4} / ' 81 | 'total: {loss:.4}' 82 | ).format( 83 | task=task, 84 | tasks=len(train_datasets), 85 | epoch=epoch, 86 | epochs=epochs_per_task, 87 | trained=batch_index*batch_size, 88 | total=dataset_size, 89 | progress=(100.*batch_index/dataset_batches), 90 | prec=float(precision), 91 | ce_loss=float(ce_loss), 92 | ewc_loss=float(ewc_loss), 93 | loss=float(loss), 94 | )) 95 | 96 | # Send test precision to the visdom server. 97 | if iteration % eval_log_interval == 0: 98 | names = [ 99 | 'task {}'.format(i+1) for i in 100 | range(len(train_datasets)) 101 | ] 102 | precs = [ 103 | utils.validate( 104 | model, test_datasets[i], test_size=test_size, 105 | cuda=cuda, verbose=False, 106 | ) if i+1 <= task else 0 for i in 107 | range(len(train_datasets)) 108 | ] 109 | title = ( 110 | 'precision (consolidated)' if consolidate else 111 | 'precision' 112 | ) 113 | visual.visualize_scalars( 114 | vis, precs, names, title, 115 | iteration 116 | ) 117 | 118 | # Send losses to the visdom server. 119 | if iteration % loss_log_interval == 0: 120 | title = 'loss (consolidated)' if consolidate else 'loss' 121 | visual.visualize_scalars( 122 | vis, 123 | [loss, ce_loss, ewc_loss], 124 | ['total', 'cross entropy', 'ewc'], 125 | title, iteration 126 | ) 127 | 128 | if consolidate and task < len(train_datasets): 129 | # estimate the fisher information of the parameters and consolidate 130 | # them in the network. 131 | print( 132 | '=> Estimating diagonals of the fisher information matrix...', 133 | flush=True, end='', 134 | ) 135 | model.consolidate(model.estimate_fisher( 136 | train_dataset, fisher_estimation_sample_size 137 | )) 138 | print(' Done!') 139 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import shutil 4 | import torch 5 | from torch.autograd import Variable 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.dataloader import default_collate 8 | from torch.nn import init 9 | 10 | 11 | def get_data_loader(dataset, batch_size, cuda=False, collate_fn=None): 12 | 13 | return DataLoader( 14 | dataset, batch_size=batch_size, 15 | shuffle=True, collate_fn=(collate_fn or default_collate), 16 | **({'num_workers': 2, 'pin_memory': True} if cuda else {}) 17 | ) 18 | 19 | 20 | def save_checkpoint(model, model_dir, epoch, precision, best=True): 21 | path = os.path.join(model_dir, model.name) 22 | path_best = os.path.join(model_dir, '{}-best'.format(model.name)) 23 | 24 | # save the checkpoint. 25 | if not os.path.exists(model_dir): 26 | os.makedirs(model_dir) 27 | torch.save({ 28 | 'state': model.state_dict(), 29 | 'epoch': epoch, 30 | 'precision': precision, 31 | }, path) 32 | 33 | # override the best model if it's the best. 34 | if best: 35 | shutil.copy(path, path_best) 36 | print('=> updated the best model of {name} at {path}'.format( 37 | name=model.name, path=path_best 38 | )) 39 | 40 | # notify that we successfully saved the checkpoint. 41 | print('=> saved the model {name} to {path}'.format( 42 | name=model.name, path=path 43 | )) 44 | 45 | 46 | def load_checkpoint(model, model_dir, best=True): 47 | path = os.path.join(model_dir, model.name) 48 | path_best = os.path.join(model_dir, '{}-best'.format(model.name)) 49 | 50 | # load the checkpoint. 51 | checkpoint = torch.load(path_best if best else path) 52 | print('=> loaded checkpoint of {name} from {path}'.format( 53 | name=model.name, path=(path_best if best else path) 54 | )) 55 | 56 | # load parameters and return the checkpoint's epoch and precision. 57 | model.load_state_dict(checkpoint['state']) 58 | epoch = checkpoint['epoch'] 59 | precision = checkpoint['precision'] 60 | return epoch, precision 61 | 62 | 63 | def validate(model, dataset, test_size=256, batch_size=32, 64 | cuda=False, verbose=True): 65 | mode = model.training 66 | model.train(mode=False) 67 | data_loader = get_data_loader(dataset, batch_size, cuda=cuda) 68 | total_tested = 0 69 | total_correct = 0 70 | for x, y in data_loader: 71 | # break on test size. 72 | if total_tested >= test_size: 73 | break 74 | # test the model. 75 | x = x.view(batch_size, -1) 76 | x = Variable(x).cuda() if cuda else Variable(x) 77 | y = Variable(y).cuda() if cuda else Variable(y) 78 | scores = model(x) 79 | _, predicted = scores.max(1) 80 | # update statistics. 81 | total_correct += int((predicted == y).sum()) 82 | total_tested += len(x) 83 | model.train(mode=mode) 84 | precision = total_correct / total_tested 85 | if verbose: 86 | print('=> precision: {:.3f}'.format(precision)) 87 | return precision 88 | 89 | 90 | def xavier_initialize(model): 91 | modules = [ 92 | m for n, m in model.named_modules() if 93 | 'conv' in n or 'linear' in n 94 | ] 95 | 96 | parameters = [ 97 | p for 98 | m in modules for 99 | p in m.parameters() if 100 | p.dim() >= 2 101 | ] 102 | 103 | for p in parameters: 104 | init.xavier_normal(p) 105 | 106 | 107 | def gaussian_intiailize(model, std=.1): 108 | for p in model.parameters(): 109 | init.normal(p, std=std) 110 | -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.cuda import FloatTensor as CUDATensor 3 | 4 | _WINDOW_CASH = {} 5 | 6 | 7 | def visualize_image(vis, tensor, name, label=None, w=250, h=250, 8 | update_window_without_label=False): 9 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 10 | title = name + ('-{}'.format(label) if label is not None else '') 11 | 12 | _WINDOW_CASH[title] = vis.image( 13 | tensor.numpy(), win=_WINDOW_CASH.get(title), 14 | opts=dict(title=title, width=w, height=h) 15 | ) 16 | 17 | # This is useful when you want to maintain the most recent images. 18 | if update_window_without_label: 19 | _WINDOW_CASH[name] = vis.image( 20 | tensor.numpy(), win=_WINDOW_CASH.get(name), 21 | opts=dict(title=name, width=w, height=h) 22 | ) 23 | 24 | 25 | def visualize_images(vis, tensor, name, label=None, w=250, h=250, 26 | update_window_without_label=False): 27 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 28 | title = name + ('-{}'.format(label) if label is not None else '') 29 | 30 | _WINDOW_CASH[title] = vis.images( 31 | tensor.numpy(), win=_WINDOW_CASH.get(title), 32 | opts=dict(title=title, width=w, height=h) 33 | ) 34 | 35 | # This is useful when you want to maintain the most recent images. 36 | if update_window_without_label: 37 | _WINDOW_CASH[name] = vis.images( 38 | tensor.numpy(), win=_WINDOW_CASH.get(name), 39 | opts=dict(title=name, width=w, height=h) 40 | ) 41 | 42 | 43 | def visualize_kernel(vis, kernel, name, label=None, w=250, h=250, 44 | update_window_without_label=False, compress_tensor=False): 45 | # Do not visualize kernels that does not exists. 46 | if kernel is None: 47 | return 48 | 49 | assert len(kernel.size()) in (2, 4) 50 | title = name + ('-{}'.format(label) if label is not None else '') 51 | kernel = kernel.cpu() if isinstance(kernel, CUDATensor) else kernel 52 | kernel_norm = kernel if len(kernel.size()) == 2 else ( 53 | (kernel**2).mean(-1).mean(-1) if compress_tensor else 54 | kernel.view( 55 | kernel.size()[0] * kernel.size()[2], 56 | kernel.size()[1] * kernel.size()[3], 57 | ) 58 | ) 59 | kernel_norm = kernel_norm.abs() 60 | 61 | visualized = ( 62 | (kernel_norm - kernel_norm.min()) / 63 | (kernel_norm.max() - kernel_norm.min()) 64 | ).numpy() 65 | 66 | _WINDOW_CASH[title] = vis.image( 67 | visualized, win=_WINDOW_CASH.get(title), 68 | opts=dict(title=title, width=w, height=h) 69 | ) 70 | 71 | # This is useful when you want to maintain the most recent images. 72 | if update_window_without_label: 73 | _WINDOW_CASH[name] = vis.image( 74 | visualized, win=_WINDOW_CASH.get(name), 75 | opts=dict(title=name, width=w, height=h) 76 | ) 77 | 78 | 79 | def visualize_scalar(vis, scalar, name, iteration): 80 | visualize_scalars( 81 | vis, 82 | [scalar] if isinstance(scalar, float) or len(scalar) == 1 else scalar, 83 | [name], name, iteration 84 | ) 85 | 86 | 87 | def visualize_scalars(vis, scalars, names, title, iteration): 88 | assert len(scalars) == len(names) 89 | # Convert scalar tensors to numpy arrays. 90 | scalars, names = list(scalars), list(names) 91 | scalars = [s.cpu() if isinstance(s, CUDATensor) else s for s in scalars] 92 | scalars = [s.detach().numpy() if hasattr(s, 'numpy') else 93 | np.array([s]) for s in scalars] 94 | multi = len(scalars) > 1 95 | num = len(scalars) 96 | 97 | options = dict( 98 | fillarea=True, 99 | legend=names, 100 | width=400, 101 | height=400, 102 | xlabel='Iterations', 103 | ylabel=title, 104 | title=title, 105 | marginleft=30, 106 | marginright=30, 107 | marginbottom=80, 108 | margintop=30, 109 | ) 110 | 111 | X = ( 112 | np.column_stack(np.array([iteration] * num)) if multi else 113 | np.array([iteration] * num) 114 | ) 115 | Y = np.column_stack(scalars) if multi else scalars[0] 116 | 117 | if title in _WINDOW_CASH: 118 | vis.line( 119 | X=X, Y=Y, win=_WINDOW_CASH[title], opts=options, update='append' 120 | ) 121 | else: 122 | _WINDOW_CASH[title] = vis.line(X=X, Y=Y, opts=options) 123 | --------------------------------------------------------------------------------