├── .gitignore ├── Archives └── Images │ ├── mnist_disjoint_GAN_FID.png │ ├── mnist_disjoint_GAN_Fitting_Capacity.png │ ├── samples_ap_mnist.png │ └── task_explained.png ├── Classifiers ├── Cifar_Classifier.py ├── Classifier.py ├── Fashion_Classifier.py └── Mnist_Classifier.py ├── Data ├── __init__.py ├── data_loader.py ├── disjoint.py ├── download.py ├── fashion.py ├── input_pipeline.py ├── load_dataset.py ├── main_data.py ├── permutations.py └── rotations.py ├── Evaluation ├── Eval_Classifier.py ├── Reviewer.py ├── __init__.py └── tools.py ├── Generative_Models ├── BEGAN.py ├── CGAN.py ├── CVAE.py ├── Conditional_Model.py ├── GAN.py ├── Generative_Model.py ├── VAE.py ├── WGAN.py ├── WGAN_GP.py ├── __init__.py ├── discriminator.py ├── encoder.py └── generator.py ├── LICENSE ├── README.md ├── Scripts └── generate_test.sh ├── Training ├── Baseline.py ├── Ewc.py ├── Ewc_samples.py ├── Generative_Replay.py ├── README.md ├── Rehearsal.py └── Trainer.py ├── environment.yml ├── log_utils.py ├── main.py ├── print_figures.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Archives/Images/mnist_disjoint_GAN_FID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Archives/Images/mnist_disjoint_GAN_FID.png -------------------------------------------------------------------------------- /Archives/Images/mnist_disjoint_GAN_Fitting_Capacity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Archives/Images/mnist_disjoint_GAN_Fitting_Capacity.png -------------------------------------------------------------------------------- /Archives/Images/samples_ap_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Archives/Images/samples_ap_mnist.png -------------------------------------------------------------------------------- /Archives/Images/task_explained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Archives/Images/task_explained.png -------------------------------------------------------------------------------- /Classifiers/Cifar_Classifier.py: -------------------------------------------------------------------------------- 1 | ''' Tooken from https://github.com/kuangliu/pytorch-cifar''' 2 | 3 | '''Dual Path Networks in PyTorch.''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from Classifiers.Classifier import Classifier, Net 9 | import torch.optim as optim 10 | 11 | 12 | class Cifar_Classifier(Classifier): 13 | def __init__(self, args): 14 | super(Cifar_Classifier, self).__init__(args) 15 | 16 | self.net = DPN26() 17 | if self.gpu_mode: 18 | self.net.cuda(self.device) 19 | self.optimizer = optim.Adam(params=self.net.parameters(), lr=self.args.lrC) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 24 | super(Bottleneck, self).__init__() 25 | self.out_planes = out_planes 26 | self.dense_depth = dense_depth 27 | 28 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 29 | self.bn1 = nn.BatchNorm2d(in_planes) 30 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 31 | self.bn2 = nn.BatchNorm2d(in_planes) 32 | self.conv3 = nn.Conv2d(in_planes, out_planes + dense_depth, kernel_size=1, bias=False) 33 | self.bn3 = nn.BatchNorm2d(out_planes + dense_depth) 34 | 35 | self.shortcut = nn.Sequential() 36 | if first_layer: 37 | self.shortcut = nn.Sequential( 38 | nn.Conv2d(last_planes, out_planes + dense_depth, kernel_size=1, stride=stride, bias=False), 39 | nn.BatchNorm2d(out_planes + dense_depth) 40 | ) 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | x = self.shortcut(x) 47 | d = self.out_planes 48 | out = torch.cat([x[:, :d, :, :] + out[:, :d, :, :], x[:, d:, :, :], out[:, d:, :, :]], 1) 49 | out = F.relu(out) 50 | return out 51 | 52 | 53 | # class DPN(nn.Module): 54 | class Cifar_Net(nn.Module): 55 | def __init__(self, cfg): 56 | super(Cifar_Net, self).__init__() 57 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 58 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 59 | 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(64) 62 | self.last_planes = 64 63 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 64 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 65 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 66 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 67 | self.linear = nn.Linear(out_planes[3] + (num_blocks[3] + 1) * dense_depth[3], 10) 68 | 69 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 70 | strides = [stride] + [1] * (num_blocks - 1) 71 | layers = [] 72 | for i, stride in enumerate(strides): 73 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i == 0)) 74 | self.last_planes = out_planes + (i + 2) * dense_depth 75 | return nn.Sequential(*layers) 76 | 77 | def weights_init(self, m): 78 | classname = m.__class__.__name__ 79 | if classname.find('Conv') != -1: 80 | m.weight.data.normal_(0.0, 0.02) 81 | elif classname.find('BatchNorm') != -1: 82 | m.weight.data.normal_(1.0, 0.02) 83 | m.bias.data.fill_(0) 84 | 85 | def forward(self, x, FID=False): 86 | 87 | x = x.view(-1, 3, 32, 32) 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | if FID: # size : 2432 96 | return out 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def DPN26(): 102 | cfg = { 103 | 'in_planes': (96, 192, 384, 768), 104 | 'out_planes': (256, 512, 1024, 2048), 105 | 'num_blocks': (2, 2, 2, 2), 106 | 'dense_depth': (16, 32, 24, 128) 107 | } 108 | return Cifar_Net(cfg) 109 | 110 | 111 | def DPN92(): 112 | cfg = { 113 | 'in_planes': (96, 192, 384, 768), 114 | 'out_planes': (256, 512, 1024, 2048), 115 | 'num_blocks': (3, 4, 20, 3), 116 | 'dense_depth': (16, 32, 24, 128) 117 | } 118 | return Cifar_Net(cfg) 119 | 120 | 121 | def test(): 122 | net = DPN92() 123 | x = torch.randn(1, 3, 32, 32) 124 | y = net(x) 125 | print(y) 126 | -------------------------------------------------------------------------------- /Classifiers/Classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import variable 6 | import torch.optim as optim 7 | from utils import variable 8 | import numpy as np 9 | import math 10 | 11 | 12 | class Classifier(object): 13 | def __init__(self, args): 14 | super(Classifier, self).__init__() 15 | 16 | self.args = args 17 | self.batchsize = args.batch_size 18 | self.gpu_mode = args.gpu_mode 19 | self.device = args.device 20 | self.save_dir = args.save_dir 21 | self.verbose = args.verbose 22 | 23 | self.net = Net() 24 | if self.gpu_mode: 25 | self.net.cuda(self.device) 26 | self.optimizer = optim.Adam(params=self.net.parameters(), lr=self.args.lrC) 27 | 28 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 29 | self.net.train() 30 | epoch_loss = 0 31 | correct = 0 32 | train_loader.shuffle_task() 33 | for data, target in train_loader: 34 | data, target = variable(data), variable(target) 35 | 36 | if self.gpu_mode: 37 | data, target = data.cuda(self.device), target.cuda(self.device) 38 | 39 | self.optimizer.zero_grad() 40 | 41 | output = self.net(data) 42 | loss = F.cross_entropy(output, target) 43 | epoch_loss += loss.item() 44 | 45 | if additional_loss is not None: 46 | regularization = additional_loss(self.net) 47 | 48 | if regularization is not None: 49 | loss += regularization 50 | 51 | loss.backward() 52 | self.optimizer.step() 53 | correct += (output.max(dim=1)[1] == target).data.sum() 54 | 55 | if self.verbose: 56 | print('Train eval : task : ' + str(ind_task) + " - correct : " + str(correct) + ' / ' + str( 57 | len(train_loader))) 58 | 59 | return epoch_loss / np.float(len(train_loader)), 100. * correct / np.float(len(train_loader)) 60 | 61 | def eval_on_task(self, test_loader, verbose=False): 62 | self.net.eval() 63 | correct = 0 64 | val_loss_classif = 0 65 | 66 | classe_prediction = np.zeros(10) 67 | classe_total = np.zeros(10) 68 | classe_wrong = np.zeros(10) # Images wrongly attributed to a particular class 69 | 70 | for data, target in test_loader: 71 | batch = variable(data) 72 | label = variable(target.squeeze()) 73 | classif = self.net(batch) 74 | loss_classif = F.nll_loss(classif, label) 75 | val_loss_classif += loss_classif.item() 76 | pred = classif.data.max(1, keepdim=True)[1] # get the index of the max log-probability 77 | correct += pred.eq(label.data.view_as(pred)).cpu().sum() 78 | 79 | for i in range(label.data.shape[0]): 80 | if pred[i].cpu()[0] == label.data[i].cpu(): 81 | classe_prediction[pred[i].cpu()[0]] += 1 82 | else: 83 | classe_wrong[pred[i].cpu()[0]] += 1 84 | classe_total[label.data[i]] += 1 85 | 86 | val_loss_classif /= (np.float(len(test_loader.sampler))) 87 | valid_accuracy = 100. * correct / np.float(len(test_loader.sampler)) 88 | 89 | if verbose: 90 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 91 | val_loss_classif, correct, (len(test_loader)*self.batchsize), 92 | 100. * correct / (len(test_loader)*self.batchsize))) 93 | 94 | for i in range(10): 95 | print('Classe {} Accuracy: {}/{} ({:.3f}%, Wrong : {})'.format( 96 | i, classe_prediction[i], classe_total[i], 97 | 100. * classe_prediction[i] / classe_total[i], classe_wrong[i])) 98 | print('\n') 99 | 100 | return val_loss_classif, valid_accuracy, classe_prediction, classe_total, classe_wrong 101 | 102 | def forward(self, x, FID=False): 103 | return self.net.forward(x, FID) 104 | 105 | def save(self, ind_task, Best=False): 106 | if not os.path.exists(self.save_dir): 107 | os.makedirs(self.save_dir) 108 | 109 | if Best: 110 | torch.save(self.net.state_dict(), os.path.join(self.save_dir, 'Best_Classifier.pkl')) 111 | else: 112 | torch.save(self.net.state_dict(), os.path.join(self.save_dir, 'Classifier_' + str(ind_task) + '.pkl')) 113 | 114 | def load_expert(self): 115 | 116 | expert_path = os.path.join(self.save_dir, '..', '..', '..', '..', '..', '..', '..', 'Classification', 117 | self.args.dataset, 118 | 'Baseline', 'Num_tasks_1', 'seed_' + str(self.args.seed), 'Best_Classifier.pkl') 119 | 120 | if not os.path.exists(os.path.join(expert_path)): 121 | print('The expert does not exist, you can train it by running :') 122 | print( 123 | 'python main.py --context Classification --task_type disjoint --method Baseline --dataset YOUR_DATASET --epochs 25 --num_task 1 --seed YOUR_SEED') 124 | 125 | self.net.load_state_dict(torch.load(expert_path)) 126 | 127 | def labelize(self, batch, nb_classes): 128 | self.net.eval() 129 | if self.gpu_mode: 130 | batch = batch.cuda(self.device) 131 | output = self.net(batch) 132 | 133 | return output[:, :nb_classes].max(dim=1)[1] 134 | 135 | def reinit(self): 136 | self.net.apply(Xavier) 137 | 138 | 139 | class Net(nn.Module): 140 | def __init__(self): 141 | super(Net, self).__init__() 142 | self.input_dim = 1 143 | self.output_dim = 1 144 | self.relu = nn.ReLU() 145 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 146 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 147 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 148 | self.conv2_drop = nn.Dropout2d() 149 | self.dropout = nn.Dropout(p=0.5) 150 | 151 | self.BN = nn.BatchNorm1d(320) 152 | self.fc1 = nn.Linear(320, 50) 153 | self.fc2 = nn.Linear(50, 10) 154 | self.apply(Xavier) 155 | 156 | def forward(self, x, FID=False): 157 | x = x.view(-1, 1, 28, 28) 158 | 159 | x = self.relu(self.maxpool2(self.conv1(x))) 160 | x = self.relu(self.maxpool2(self.conv2_drop(self.conv2(x)))) 161 | x = x.view(-1, 320) 162 | x = self.BN(x) 163 | if FID: 164 | return x 165 | x = self.relu(self.fc1(x)) 166 | x = self.fc2(x) 167 | return F.log_softmax(x, dim=1) 168 | 169 | def weights_init(self, m): 170 | classname = m.__class__.__name__ 171 | if classname.find('Conv') != -1: 172 | m.weight.data.normal_(0.0, 0.02) 173 | elif classname.find('BatchNorm') != -1: 174 | m.weight.data.normal_(1.0, 0.02) 175 | m.bias.data.fill_(0) 176 | 177 | 178 | def Xavier(m): 179 | if m.__class__.__name__ == 'Linear': 180 | fan_in, fan_out = m.weight.data.size(1), m.weight.data.size(0) 181 | std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out)) 182 | a = math.sqrt(3.0) * std 183 | m.weight.data.uniform_(-a, a) 184 | m.bias.data.fill_(0.0) 185 | -------------------------------------------------------------------------------- /Classifiers/Fashion_Classifier.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from Classifiers.Classifier import Classifier, Net 7 | 8 | 9 | class Fashion_Classifier(Classifier): 10 | 11 | def __init__(self, args): 12 | super(Fashion_Classifier, self).__init__(args) -------------------------------------------------------------------------------- /Classifiers/Mnist_Classifier.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from Classifiers.Classifier import Classifier, Net 4 | 5 | class Mnist_Classifier(Classifier): 6 | def __init__(self, args): 7 | super(Mnist_Classifier, self).__init__(args) 8 | 9 | class Mnist_Net(Net): 10 | 11 | def __init__(self): 12 | super(Mnist_Net, self).__init__() -------------------------------------------------------------------------------- /Data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Data/__init__.py -------------------------------------------------------------------------------- /Data/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from log_utils import * 4 | 5 | class DataLoader(object): 6 | def __init__(self, data, args): 7 | 8 | ''' 9 | 10 | dataset.shape = [num , 3, image_number] 11 | dataset[0 , 1, :] # all data from task 0 12 | dataset[0 , 2, :] # all label from task 0 13 | 14 | ''' 15 | 16 | self.dataset = data 17 | self.batch_size = args.batch_size 18 | n_tasks = args.num_task 19 | self.length = n_tasks 20 | self.current_sample = 0 21 | self.current_task = 0 22 | self.sampler = self 23 | 24 | def __iter__(self): 25 | return self 26 | 27 | def next(self): 28 | return self.__next__() 29 | 30 | def __next__(self): 31 | ''' 32 | 33 | :return: (data, label) with shape batch_size 34 | 35 | ''' 36 | 37 | if self.current_sample == self.dataset[self.current_task][1].shape[0]: 38 | self.current_sample = 0 # reinitialize 39 | self.shuffle_task() 40 | raise StopIteration 41 | elif self.current_sample + self.batch_size >= self.dataset[self.current_task][1].shape[0]: 42 | last_size = self.dataset[self.current_task][1].shape[0] - self.current_sample 43 | j = range(self.current_sample, self.current_sample + last_size) 44 | self.current_sample = self.current_sample + last_size 45 | j = torch.LongTensor(j) 46 | return self.dataset[self.current_task][1][j], self.dataset[self.current_task][2][j] 47 | else: 48 | j = range(self.current_sample, self.current_sample + self.batch_size) 49 | self.current_sample = self.current_sample + self.batch_size 50 | j = torch.LongTensor(j) 51 | return self.dataset[self.current_task][1][j], self.dataset[self.current_task][2][j] 52 | 53 | def __len__(self): 54 | return len(self.dataset[self.current_task][1]) 55 | 56 | def __getitem__(self, key): 57 | self.current_sample = 0 58 | self.current_task = key 59 | return self 60 | 61 | 62 | def shuffle_task(self): 63 | indices = torch.randperm(len(self.dataset[self.current_task][1])) 64 | self.dataset[self.current_task][1] = self.dataset[self.current_task][1][indices].clone() 65 | self.dataset[self.current_task][2] = self.dataset[self.current_task][2][indices].clone() 66 | 67 | def get_sample(self, number): 68 | indices = torch.randperm(len(self))[0:number] 69 | 70 | return self.dataset[self.current_task][1][indices], self.dataset[self.current_task][2][indices] 71 | 72 | def concatenate(self, new_data, task=0): 73 | 74 | ''' 75 | 76 | :param new_data: data to add to the actual task 77 | :return: the actual dataset with supplementary data inside 78 | ''' 79 | 80 | self.dataset[self.current_task][1] = torch.cat((self.dataset[self.current_task][1], new_data.dataset[task][1]), 0).clone() 81 | self.dataset[self.current_task][2] = torch.cat((self.dataset[self.current_task][2], new_data.dataset[task][2]), 0).clone() 82 | 83 | return self 84 | 85 | def get_current_task(self): 86 | return self.current_task 87 | 88 | def save(self, path): 89 | torch.save(self.dataset, path) 90 | 91 | def visualize_sample(self, path , number, shape): 92 | data, target = self.get_sample(number) 93 | 94 | # get sample in order from 0 to 9 95 | target, order = target.sort() 96 | data = data[order] 97 | 98 | image_frame_dim = int(np.floor(np.sqrt(number))) 99 | 100 | if shape[2] == 1: 101 | data = data.numpy().reshape(number, shape[0], shape[1], shape[2]) 102 | save_images(data[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 103 | path) 104 | else: 105 | data = data.numpy().reshape(number, shape[2], shape[1], shape[0]) 106 | make_samples_batche(data[:number], number, path) 107 | 108 | def increase_size(self, increase_factor): 109 | 110 | self.dataset[self.current_task][1] = torch.cat([self.dataset[self.current_task][1]]*increase_factor, 0) 111 | self.dataset[self.current_task][2] = torch.cat([self.dataset[self.current_task][2]]*increase_factor, 0) 112 | 113 | return self 114 | 115 | 116 | -------------------------------------------------------------------------------- /Data/disjoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | import torch 4 | 5 | import numpy as np 6 | 7 | from torchvision import datasets, transforms 8 | 9 | 10 | class Disjoint(object): 11 | def __init__(self, args): 12 | super(Disjoint, self).__init__() 13 | 14 | self.upperbound = args.upperbound 15 | self.n_tasks = args.n_tasks 16 | self.i = args.i 17 | self.train_file = args.train_file 18 | self.test_file = args.test_file 19 | self.dataset = args.dataset 20 | 21 | if self.upperbound: 22 | self.o_train = os.path.join(args.o, 'upperbound_disjoint_' + str(self.n_tasks) + '_train.pt') 23 | self.o_test = os.path.join(args.o, 'upperbound_disjoint_' + str(self.n_tasks) + '_test.pt') 24 | else: 25 | self.o_train = os.path.join(args.o, 'disjoint_' + str(self.n_tasks) + '_train.pt') 26 | self.o_test = os.path.join(args.o, 'disjoint_' + str(self.n_tasks) + '_test.pt') 27 | 28 | 29 | def load_cifar10(self): 30 | transform_train = transforms.Compose([ 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 34 | ]) 35 | dataset_train = datasets.CIFAR10(root='./Datasets', train=True, download=True, transform=transform_train) 36 | tensor_data = torch.Tensor(len(dataset_train),3,32,32) 37 | tensor_label = torch.LongTensor(len(dataset_train)) 38 | 39 | for i in range(len(dataset_train)): 40 | tensor_data[i] = dataset_train[i][0] 41 | tensor_label[i] = dataset_train[i][1] 42 | 43 | transform_test = transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 46 | ]) 47 | 48 | dataset_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 49 | 50 | tensor_test = torch.Tensor(len(dataset_test),3,32,32) 51 | tensor_label_test = torch.LongTensor(len(dataset_test)) 52 | 53 | for i in range(len(dataset_test)): 54 | tensor_test[i] = dataset_test[i][0] 55 | tensor_label_test[i] = dataset_test[i][1] 56 | 57 | #testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 58 | 59 | return tensor_data, tensor_label, tensor_test, tensor_label_test 60 | 61 | 62 | 63 | def formating_data(self): 64 | 65 | 66 | 67 | tasks_tr = [] 68 | tasks_te = [] 69 | 70 | if self.dataset == 'cifar10': 71 | x_tr, y_tr, x_te, y_te = self.load_cifar10() 72 | 73 | x_tr = x_tr.float().view(x_tr.size(0), -1) 74 | x_te = x_te.float().view(x_te.size(0), -1) 75 | else: 76 | assert os.path.isfile(os.path.join(self.i, self.train_file)), print(os.path.join(self.i, self.train_file)) 77 | assert os.path.isfile(os.path.join(self.i, self.test_file)), print(os.path.join(self.i, self.test_file)) 78 | 79 | x_tr, y_tr = torch.load(os.path.join(self.i, self.train_file)) 80 | x_te, y_te = torch.load(os.path.join(self.i, self.test_file)) 81 | 82 | x_tr = x_tr.float().view(x_tr.size(0), -1) / 255.0 83 | x_te = x_te.float().view(x_te.size(0), -1) / 255.0 84 | 85 | y_tr = y_tr.view(-1).long() 86 | y_te = y_te.view(-1).long() 87 | 88 | cpt = int(10 / self.n_tasks) 89 | 90 | for t in range(self.n_tasks): 91 | if self.upperbound: 92 | c1 = 0 93 | else: 94 | c1 = t * cpt 95 | c2 = (t + 1) * cpt 96 | i_tr = ((y_tr >= c1) & (y_tr < c2)).nonzero().view(-1) 97 | i_te = ((y_te >= c1) & (y_te < c2)).nonzero().view(-1) 98 | tasks_tr.append([(c1, c2), x_tr[i_tr].clone(), y_tr[i_tr].clone()]) 99 | tasks_te.append([(c1, c2), x_te[i_te].clone(), y_te[i_te].clone()]) 100 | 101 | torch.save(tasks_tr, self.o_train) 102 | torch.save(tasks_te, self.o_test) 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser() 107 | 108 | parser.add_argument('--i', default='raw/cifar100.pt', help='input directory') 109 | parser.add_argument('--o', default='cifar100.pt', help='output file') 110 | parser.add_argument('--n_tasks', default=10, type=int, help='number of tasks') 111 | parser.add_argument('--seed', default=0, type=int, help='random seed') 112 | args = parser.parse_args() 113 | 114 | torch.manual_seed(args.seed) 115 | 116 | DataFormater = Disjoint() 117 | DataFormater.formating_data(args) 118 | -------------------------------------------------------------------------------- /Data/download.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py 4 | Downloads the following: 5 | - Celeb-A dataset 6 | - LSUN dataset 7 | - MNIST dataset 8 | """ 9 | 10 | from __future__ import print_function 11 | import os 12 | import sys 13 | import gzip 14 | import json 15 | import shutil 16 | import zipfile 17 | import argparse 18 | import requests 19 | import subprocess 20 | from tqdm import tqdm 21 | from six.moves import urllib 22 | 23 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 24 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], 25 | help='name of dataset to download [celebA, lsun, mnist]') 26 | 27 | def download(url, dirpath): 28 | filename = url.split('/')[-1] 29 | filepath = os.path.join(dirpath, filename) 30 | u = urllib.request.urlopen(url) 31 | f = open(filepath, 'wb') 32 | filesize = int(u.headers["Content-Length"]) 33 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 34 | 35 | downloaded = 0 36 | block_sz = 8192 37 | status_width = 70 38 | while True: 39 | buf = u.read(block_sz) 40 | if not buf: 41 | print('') 42 | break 43 | else: 44 | print('', end='\r') 45 | downloaded += len(buf) 46 | f.write(buf) 47 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 48 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize)) 49 | print(status, end='') 50 | sys.stdout.flush() 51 | f.close() 52 | return filepath 53 | 54 | def download_file_from_google_drive(id, destination): 55 | URL = "https://docs.google.com/uc?export=download" 56 | session = requests.Session() 57 | 58 | response = session.get(URL, params={ 'id': id }, stream=True) 59 | token = get_confirm_token(response) 60 | 61 | if token: 62 | params = { 'id' : id, 'confirm' : token } 63 | response = session.get(URL, params=params, stream=True) 64 | 65 | save_response_content(response, destination) 66 | 67 | def get_confirm_token(response): 68 | for key, value in response.cookies.items(): 69 | if key.startswith('download_warning'): 70 | return value 71 | return None 72 | 73 | def save_response_content(response, destination, chunk_size=32*1024): 74 | total_size = int(response.headers.get('content-length', 0)) 75 | with open(destination, "wb") as f: 76 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 77 | unit='B', unit_scale=True, desc=destination): 78 | if chunk: # filter out keep-alive new chunks 79 | f.write(chunk) 80 | 81 | def unzip(filepath): 82 | print("Extracting: " + filepath) 83 | dirpath = os.path.dirname(filepath) 84 | with zipfile.ZipFile(filepath) as zf: 85 | zf.extractall(dirpath) 86 | os.remove(filepath) 87 | 88 | def download_celeb_a(dirpath): 89 | data_dir = 'celebA' 90 | if os.path.exists(os.path.join(dirpath, data_dir)): 91 | print('Found Celeb-A - skip') 92 | return 93 | 94 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 95 | save_path = os.path.join(dirpath, filename) 96 | 97 | if os.path.exists(save_path): 98 | print('[*] {} already exists'.format(save_path)) 99 | else: 100 | download_file_from_google_drive(drive_id, save_path) 101 | 102 | zip_dir = '' 103 | with zipfile.ZipFile(save_path) as zf: 104 | zip_dir = zf.namelist()[0] 105 | zf.extractall(dirpath) 106 | os.remove(save_path) 107 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 108 | 109 | def _list_categories(tag): 110 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 111 | f = urllib.request.urlopen(url) 112 | return json.loads(f.read()) 113 | 114 | def _download_lsun(out_dir, category, set_name, tag): 115 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 116 | '&category={category}&set={set_name}'.format(**locals()) 117 | print(url) 118 | if set_name == 'test': 119 | out_name = 'test_lmdb.zip' 120 | else: 121 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 122 | out_path = os.path.join(out_dir, out_name) 123 | cmd = ['curl', url, '-o', out_path] 124 | print('Downloading', category, set_name, 'set') 125 | subprocess.call(cmd) 126 | 127 | def download_lsun(dirpath): 128 | data_dir = os.path.join(dirpath, 'lsun') 129 | if os.path.exists(data_dir): 130 | print('Found LSUN - skip') 131 | return 132 | else: 133 | os.mkdir(data_dir) 134 | 135 | tag = 'latest' 136 | #categories = _list_categories(tag) 137 | categories = ['bedroom'] 138 | 139 | for category in categories: 140 | _download_lsun(data_dir, category, 'train', tag) 141 | _download_lsun(data_dir, category, 'val', tag) 142 | _download_lsun(data_dir, '', 'test', tag) 143 | 144 | def download_mnist(dirpath): 145 | data_dir = os.path.join(dirpath, 'mnist') 146 | if os.path.exists(data_dir): 147 | print('Found MNIST - skip') 148 | return 149 | else: 150 | os.mkdir(data_dir) 151 | url_base = 'http://yann.lecun.com/exdb/mnist/' 152 | file_names = ['train-images-idx3-ubyte.gz', 153 | 'train-labels-idx1-ubyte.gz', 154 | 't10k-images-idx3-ubyte.gz', 155 | 't10k-labels-idx1-ubyte.gz'] 156 | for file_name in file_names: 157 | url = (url_base+file_name).format(**locals()) 158 | print(url) 159 | out_path = os.path.join(data_dir,file_name) 160 | cmd = ['curl', url, '-o', out_path] 161 | print('Downloading ', file_name) 162 | subprocess.call(cmd) 163 | cmd = ['gzip', '-d', out_path] 164 | print('Decompressing ', file_name) 165 | subprocess.call(cmd) 166 | 167 | def prepare_data_dir(path = './data'): 168 | if not os.path.exists(path): 169 | os.mkdir(path) 170 | 171 | if __name__ == '__main__': 172 | args = parser.parse_args() 173 | prepare_data_dir() 174 | 175 | if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']): 176 | download_celeb_a('./celebA_data') 177 | if 'lsun' in args.datasets: 178 | download_lsun('./data') 179 | if 'mnist' in args.datasets: 180 | download_mnist('./data') -------------------------------------------------------------------------------- /Data/fashion.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import errno 7 | import torch 8 | import codecs 9 | 10 | 11 | class fashion(data.Dataset): 12 | """`MNIST `_ Dataset. 13 | Args: 14 | root (string): Root directory of dataset where ``processed/training.pt`` 15 | and ``processed/test.pt`` exist. 16 | train (bool, optional): If True, creates dataset from ``training.pt``, 17 | otherwise from ``test.pt``. 18 | download (bool, optional): If true, downloads the dataset from the internet and 19 | puts it in root directory. If dataset is already downloaded, it is not 20 | downloaded again. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | """ 26 | urls = [ 27 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz', 28 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz', 29 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 30 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', 31 | ] 32 | raw_folder = 'raw' 33 | processed_folder = 'processed' 34 | training_file = 'training.pt' 35 | test_file = 'test.pt' 36 | 37 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 38 | self.root = os.path.expanduser(root) 39 | self.transform = transform 40 | self.target_transform = target_transform 41 | self.train = train # training set or test set 42 | 43 | if download: 44 | self.download() 45 | 46 | if not self._check_exists(): 47 | raise RuntimeError('Dataset not found.' + 48 | ' You can use download=True to download it') 49 | 50 | if self.train: 51 | self.train_data, self.train_labels = torch.load( 52 | os.path.join(root, self.processed_folder, self.training_file)) 53 | else: 54 | self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file)) 55 | 56 | def __getitem__(self, index): 57 | """ 58 | Args: 59 | index (int): Index 60 | Returns: 61 | tuple: (image, target) where target is index of the target class. 62 | """ 63 | if self.train: 64 | img, target = self.train_data[index], self.train_labels[index] 65 | else: 66 | img, target = self.test_data[index], self.test_labels[index] 67 | 68 | # doing this so that it is consistent with all other datasets 69 | # to return a PIL Image 70 | img = Image.fromarray(img.numpy(), mode='L') 71 | 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | 75 | if self.target_transform is not None: 76 | target = self.target_transform(target) 77 | 78 | return img, target 79 | 80 | def __len__(self): 81 | if self.train: 82 | return len(self.train_data) 83 | else: 84 | return len(self.test_data) 85 | 86 | def _check_exists(self): 87 | 88 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ 89 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) 90 | 91 | def download(self): 92 | """Download the MNIST data if it doesn't exist in processed_folder already.""" 93 | from six.moves import urllib 94 | import gzip 95 | 96 | if self._check_exists(): 97 | return 98 | 99 | # download files 100 | try: 101 | os.makedirs(os.path.join(self.root, self.raw_folder)) 102 | os.makedirs(os.path.join(self.root, self.processed_folder)) 103 | except OSError as e: 104 | if e.errno == errno.EEXIST: 105 | pass 106 | else: 107 | raise 108 | 109 | for url in self.urls: 110 | print('Downloading ' + url) 111 | data = urllib.request.urlopen(url) 112 | filename = url.rpartition('/')[2] 113 | file_path = os.path.join(self.root, self.raw_folder, filename) 114 | with open(file_path, 'wb') as f: 115 | f.write(data.read()) 116 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 117 | gzip.GzipFile(file_path) as zip_f: 118 | out_f.write(zip_f.read()) 119 | os.unlink(file_path) 120 | 121 | # process and save as torch files 122 | print('Processing...') 123 | 124 | training_set = ( 125 | read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')), 126 | read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte')) 127 | ) 128 | test_set = ( 129 | read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')), 130 | read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte')) 131 | ) 132 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f: 133 | torch.save(training_set, f) 134 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f: 135 | torch.save(test_set, f) 136 | 137 | print('Done!') 138 | 139 | 140 | def get_int(b): 141 | return int(codecs.encode(b, 'hex'), 16) 142 | 143 | 144 | def parse_byte(b): 145 | if isinstance(b, str): 146 | return ord(b) 147 | return b 148 | 149 | 150 | def read_label_file(path): 151 | with open(path, 'rb') as f: 152 | data = f.read() 153 | assert get_int(data[:4]) == 2049 154 | length = get_int(data[4:8]) 155 | labels = [parse_byte(b) for b in data[8:]] 156 | assert len(labels) == length 157 | return torch.LongTensor(labels) 158 | 159 | 160 | def read_image_file(path): 161 | with open(path, 'rb') as f: 162 | data = f.read() 163 | assert get_int(data[:4]) == 2051 164 | length = get_int(data[4:8]) 165 | num_rows = get_int(data[8:12]) 166 | num_cols = get_int(data[12:16]) 167 | images = [] 168 | idx = 16 169 | for l in range(length): 170 | img = [] 171 | images.append(img) 172 | for r in range(num_rows): 173 | row = [] 174 | img.append(row) 175 | for c in range(num_cols): 176 | row.append(parse_byte(data[idx])) 177 | idx += 1 178 | assert len(images) == length 179 | return torch.ByteTensor(images).view(-1, 28, 28) 180 | 181 | -------------------------------------------------------------------------------- /Data/input_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from PIL import Image, ImageEnhance 5 | from torchvision.datasets import ImageFolder 6 | import torchvision.transforms as transforms 7 | 8 | from torch.utils.data.dataset import Dataset 9 | 10 | import os 11 | 12 | def get_annotations_map(VAL_PATH): 13 | valAnnotationsPath = VAL_PATH + '/val_annotations.txt' 14 | valAnnotationsFile = open(valAnnotationsPath, 'r') 15 | valAnnotationsContents = valAnnotationsFile.read() 16 | valAnnotations = {} 17 | 18 | for line in valAnnotationsContents.splitlines(): 19 | pieces = line.strip().split() 20 | valAnnotations[pieces[0]] = pieces[1] 21 | return valAnnotations 22 | 23 | 24 | def get_image_folders(TRAIN_DIR): 25 | """ 26 | Build an input pipeline for training and evaluation. 27 | For training data it does data augmentation. 28 | """ 29 | 30 | enhancers = { 31 | 0: lambda image, f: ImageEnhance.Color(image).enhance(f), 32 | 1: lambda image, f: ImageEnhance.Contrast(image).enhance(f), 33 | 2: lambda image, f: ImageEnhance.Brightness(image).enhance(f), 34 | 3: lambda image, f: ImageEnhance.Sharpness(image).enhance(f) 35 | } 36 | 37 | # intensities of enhancers 38 | factors = { 39 | 0: lambda: np.clip(np.random.normal(1.0, 0.3), 0.4, 1.6), 40 | 1: lambda: np.clip(np.random.normal(1.0, 0.15), 0.7, 1.3), 41 | 2: lambda: np.clip(np.random.normal(1.0, 0.15), 0.7, 1.3), 42 | 3: lambda: np.clip(np.random.normal(1.0, 0.3), 0.4, 1.6), 43 | } 44 | 45 | # randomly change color of an image 46 | def enhance(image): 47 | order = [0, 1, 2, 3] 48 | np.random.shuffle(order) 49 | # random enhancers in random order 50 | for i in order: 51 | f = factors[i]() 52 | image = enhancers[i](image, f) 53 | return image 54 | 55 | def rotate(image): 56 | degree = np.clip(np.random.normal(0.0, 15.0), -40.0, 40.0) 57 | return image.rotate(degree, Image.BICUBIC) 58 | 59 | # training data augmentation on the fly 60 | train_transform = transforms.Compose([ 61 | transforms.Lambda(rotate), 62 | #transforms.RandomCrop(56), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.Lambda(enhance), 65 | transforms.ToTensor(), 66 | transforms.Normalize( 67 | mean=[0.485, 0.456, 0.406], 68 | std=[0.229, 0.224, 0.225] 69 | ), 70 | ]) 71 | 72 | # mean and std are taken from here: 73 | # http://pytorch.org/docs/master/torchvision/models.html 74 | train_folder = ImageFolder(TRAIN_DIR, train_transform) 75 | 76 | return train_folder 77 | 78 | def get_test_image_folders(path): 79 | 80 | 81 | num_classes = 10 82 | TRAIN_DIR=path+'tiny-imagenet-200/training' 83 | VAL_DIR=path+'tiny-imagenet-200/validation' 84 | 85 | 86 | # for validation data 87 | val_transform = transforms.Compose([ 88 | #transforms.CenterCrop(56), 89 | transforms.ToTensor(), 90 | transforms.Normalize( 91 | mean=[0.485, 0.456, 0.406], 92 | std=[0.229, 0.224, 0.225] 93 | ), 94 | ]) 95 | 96 | val_annotations_map = get_annotations_map(VAL_DIR) 97 | 98 | #val_folder = ImageFolder(VAL_DIR, val_transform) 99 | 100 | 101 | #X_train = np.zeros([num_classes * 500, 3, 64, 64], dtype='uint8') 102 | #y_train = np.zeros([num_classes * 500], dtype='uint8') 103 | 104 | trainPath = TRAIN_DIR 105 | 106 | i = 0 107 | j = 0 108 | annotations = {} 109 | for sChild in os.listdir(trainPath): 110 | sChildPath = os.path.join(os.path.join(trainPath, sChild), 'images') 111 | annotations[sChild] = j 112 | ''' 113 | for c in os.listdir(sChildPath): 114 | X = np.array(Image.open(os.path.join(sChildPath, c))) 115 | if len(np.shape(X)) == 2: 116 | X_train[i] = np.array([X, X, X]) 117 | else: 118 | X_train[i] = np.transpose(X, (2, 0, 1)) 119 | y_train[i] = j 120 | i += 1 121 | ''' 122 | j += 1 123 | if (j >= num_classes): 124 | break 125 | 126 | print('loading test images...') 127 | 128 | X_test = np.zeros([num_classes * 50, 3, 64, 64], dtype='uint8') 129 | y_test = np.zeros([num_classes * 50], dtype='uint8') 130 | 131 | i = 0 132 | testPath = VAL_DIR + '/images' 133 | for sChild in os.listdir(testPath): 134 | if val_annotations_map[sChild] in annotations.keys(): 135 | sChildPath = os.path.join(testPath, sChild) 136 | X = np.array(Image.open(sChildPath)) 137 | if len(np.shape(X)) == 2: 138 | X_test[i] = np.array([X, X, X]) 139 | else: 140 | X_test[i] = np.transpose(X, (2, 0, 1)) 141 | y_test[i] = annotations[val_annotations_map[sChild]] 142 | i += 1 143 | else: 144 | pass 145 | 146 | return DataTest(torch.from_numpy(X_test), torch.from_numpy(y_test)), y_test 147 | 148 | class DataTest(Dataset): 149 | 150 | def __init__(self, data_tensor, target_tensor): 151 | assert data_tensor.size(0) == target_tensor.size(0) 152 | 153 | print(data_tensor.type(torch.FloatTensor).shape) 154 | 155 | self.data_tensor = data_tensor.type(torch.FloatTensor) 156 | self.target_tensor = target_tensor.type(torch.LongTensor) 157 | 158 | def __getitem__(self, index): 159 | return self.data_tensor[index], self.target_tensor[index] 160 | 161 | def __len__(self): 162 | 163 | return self.data_tensor.size(0) 164 | 165 | # there is no annotation in this test set , therefor we can not use it for evaluation 166 | ''' 167 | def get_test_image_folders(TEST_DIR): 168 | """ 169 | Build an input pipeline for training and evaluation. 170 | For training data it does data augmentation. 171 | """ 172 | 173 | 174 | # for validation data 175 | test_transform = transforms.Compose([ 176 | #transforms.CenterCrop(56), 177 | transforms.ToTensor(), 178 | transforms.Normalize( 179 | mean=[0.485, 0.456, 0.406], 180 | std=[0.229, 0.224, 0.225] 181 | ), 182 | ]) 183 | 184 | # mean and std are taken from here: 185 | # http://pytorch.org/docs/master/torchvision/models.html 186 | test_folder = ImageFolder(TEST_DIR, test_transform) 187 | return test_folder 188 | ''' -------------------------------------------------------------------------------- /Data/load_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data.dataset import Dataset 7 | from torchvision import datasets, transforms 8 | 9 | import utils 10 | from Data.fashion import fashion 11 | from Data.input_pipeline import get_image_folders, get_test_image_folders 12 | 13 | 14 | class Subset(Dataset): 15 | def __init__(self, dataset, indices): 16 | self.dataset = dataset 17 | self.indices = indices 18 | 19 | def __getitem__(self, idx): 20 | return self.dataset[self.indices[idx]] 21 | 22 | def __len__(self): 23 | return len(self.indices) 24 | 25 | 26 | def load_dataset_full(data_dir, dataset, num_examples=50000): 27 | 28 | fas=False 29 | path = os.path.join(data_dir, 'Datasets', dataset) 30 | 31 | if dataset == 'mnist': 32 | dataset = datasets.MNIST(path, train=True, download=True, transform=transforms.ToTensor()) 33 | dataset_train = Subset(dataset, range(num_examples)) 34 | dataset_val = Subset(dataset, range(50000, 60000)) 35 | elif dataset == 'fashion': 36 | if fas: 37 | dataset = datasets.FashionMNIST(path, train=True, download=True, transform=transforms.ToTensor()) 38 | else: 39 | 40 | dataset = fashion(path, train=True, download=True, transform=transforms.ToTensor()) 41 | dataset_train = Subset(dataset, range(num_examples)) 42 | dataset_val = Subset(dataset, range(50000, 60000)) 43 | elif dataset == 'cifar10': 44 | if num_examples > 45000: num_examples = 45000 # does not work if num_example > 50000 45 | transform = transforms.Compose( 46 | [transforms.ToTensor()]) 47 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 48 | dataset = datasets.CIFAR10(root=path, train=True, download=True, transform=transform) 49 | dataset_train = Subset(dataset, range(num_examples)) 50 | dataset_val = Subset(dataset, range(45000, 50000)) 51 | elif dataset == 'lsun': 52 | transform = transforms.Compose([ 53 | transforms.Scale(64), 54 | transforms.CenterCrop(64), 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 57 | ]) 58 | dataset_train = datasets.LSUN(db_path=path+'/LSUN/', classes=['bedroom_train', 'bridge_train', 'church_outdoor_train', 'classroom_train', 59 | 'conference_room_train', 'dining_room_train', 'kitchen_train', 60 | 'living_room_train', 'restaurant_train', 'tower_train'],transform=transform) 61 | 62 | dataset_val = datasets.LSUN(db_path=path+'/LSUN/', classes=['bedroom_val', 'bridge_val', 'church_outdoor_val', 'classroom_val', 63 | 'conference_room_val', 'dining_room_val', 'kitchen_val', 64 | 'living_room_val', 'restaurant_val', 'tower_val'],transform=transform) 65 | elif dataset == 'timagenet': 66 | dataset = get_image_folders(path+'tiny-imagenet-200/training') 67 | 68 | size = len(dataset) 69 | indices = torch.randperm(size) 70 | 71 | dataset_train = Subset(dataset, indices[:int(size*0.8)]) 72 | dataset_val = Subset(dataset, indices[int(size*0.8):]) 73 | 74 | 75 | 76 | list_classes_train = np.asarray([dataset_train[i][1] for i in range(len(dataset_train))]) 77 | list_classes_val = np.asarray([dataset_val[i][1] for i in range(len(dataset_val))]) 78 | 79 | if dataset == 'timagenet': 80 | #we only use 10 classes in the dataset 81 | list_classes_train = np.where(list_classes_train < 10)[0] 82 | list_classes_val = np.where(list_classes_val < 10)[0] 83 | 84 | dataset_train = Subset(dataset_val, list_classes_train) 85 | dataset_val = Subset(dataset_val, list_classes_train) 86 | 87 | return dataset_train, dataset_val, list_classes_train, list_classes_val 88 | 89 | 90 | 91 | def load_dataset_test(data_dir, dataset, batch_size): 92 | list_classes_test = [] 93 | 94 | fas=False 95 | 96 | path = os.path.join(data_dir, 'Datasets', dataset) 97 | 98 | if dataset == 'mnist': 99 | dataset_test = datasets.MNIST(path, train=False, download=True, transform=transforms.Compose([transforms.ToTensor()])) 100 | elif dataset == 'fashion': 101 | if fas: 102 | dataset_test = DataLoader( 103 | datasets.FashionMNIST(path, train=False, download=True, transform=transforms.Compose( 104 | [transforms.ToTensor()])), 105 | batch_size=batch_size) 106 | else: 107 | dataset_test = fashion(path, train=False, download=True, transform=transforms.ToTensor()) 108 | 109 | elif dataset == 'cifar10': 110 | transform = transforms.Compose( 111 | [transforms.ToTensor(), 112 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 113 | 114 | dataset_test = datasets.CIFAR10(root=path, train=False, 115 | download=True, transform=transform) 116 | 117 | elif dataset == 'celebA': 118 | dataset_test = utils.load_celebA(path + 'celebA', transform=transforms.Compose( 119 | [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=batch_size) 120 | elif dataset == 'timagenet': 121 | dataset_test, labels = get_test_image_folders(path) 122 | list_classes_test = np.asarray([labels[i] for i in range(len(dataset_test))]) 123 | dataset_test = Subset(dataset_test, np.where(list_classes_test < 10)[0]) 124 | list_classes_test = np.where(list_classes_test < 10)[0] 125 | 126 | list_classes_test = np.asarray([dataset_test[i][1] for i in range(len(dataset_test))]) 127 | 128 | return dataset_test, list_classes_test 129 | 130 | 131 | def get_iter_dataset(dataset, list_classe=[], batch_size=64, classe=None): 132 | if classe is not None: 133 | dataset = Subset(dataset, np.where(list_classe == classe)[0]) 134 | 135 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False) 136 | 137 | return data_loader 138 | -------------------------------------------------------------------------------- /Data/main_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | # from CUB200 import CUB200_Disjoint 5 | # from cifar100 import Cifar100_Disjoint 6 | from disjoint import Disjoint 7 | from rotations import Rotations 8 | from permutations import Permutations 9 | from fashion import fashion 10 | from torchvision import datasets, transforms 11 | from torch.utils.data.dataset import Dataset 12 | 13 | import torch 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--dir', default='../Archives', help='input directory') 18 | parser.add_argument('--i', default='Data', help='input directory') 19 | parser.add_argument('--train_file', default='', help='input directory') 20 | parser.add_argument('--test_file', default='', help='input directory') 21 | 22 | parser.add_argument('--upperbound', default=False, type=bool) 23 | parser.add_argument('--task', default='disjoint', choices=['rotations', 'permutations', 24 | 'disjoint', 'cifar100', 'CUB200'], 25 | help='type of task to create', ) 26 | parser.add_argument('--dataset', default='mnist', type=str, choices=['mnist', 'fashion', 'cifar10']) 27 | parser.add_argument('--n_tasks', default=3, type=int, help='number of tasks') 28 | parser.add_argument('--seed', default=0, type=int, help='random seed') 29 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 30 | parser.add_argument('--imageSize', type=int, default=224, help='input batch size') 31 | parser.add_argument('--min_rot', default=0., type=float, help='minimum rotation') 32 | parser.add_argument('--max_rot', default=90., type=float, help='maximum rotation') 33 | args = parser.parse_args() 34 | 35 | torch.manual_seed(args.seed) 36 | 37 | 38 | print(str(args).replace(',', ',\n')) 39 | 40 | 41 | class Subset(Dataset): 42 | def __init__(self, dataset, indices): 43 | self.dataset = dataset 44 | self.indices = indices 45 | 46 | def __getitem__(self, idx): 47 | return self.dataset[self.indices[idx]] 48 | 49 | def __len__(self): 50 | return len(self.indices) 51 | 52 | 53 | args.i = os.path.join(args.dir, args.i) 54 | args.o = os.path.join(args.i, 'Tasks', args.dataset) 55 | args.i = os.path.join(args.i, 'Datasets', args.dataset) 56 | args.train_file = 'training.pt' 57 | args.test_file = 'test.pt' 58 | 59 | # download data if possible 60 | if args.dataset == 'mnist': 61 | datasets.MNIST(args.i, train=True, download=True, transform=transforms.ToTensor()) 62 | args.i = os.path.join(args.i, 'MNIST') 63 | elif args.dataset == 'fashion': 64 | fashion(args.i, train=True, download=True, transform=transforms.ToTensor()) 65 | elif args.dataset == 'cifar10': 66 | print("DL one later") 67 | elif args.dataset == 'cifar100': 68 | args.train_file = 'cifar100.pt' 69 | if not os.path.isdir(args.i): 70 | print('This dataset should be downloaded manually') 71 | elif args.dataset == 'CUB200': 72 | args.i = args.i = os.path.join(args.i, 'images') 73 | if not os.path.isdir(args.i): 74 | print('This dataset should be downloaded manually') 75 | 76 | if not os.path.exists(args.o): 77 | os.makedirs(args.o) 78 | 79 | args.i = os.path.join(args.i, 'processed') 80 | 81 | if args.task == 'rotations': 82 | DataFormatter = Rotations(args) 83 | elif args.task == 'permutations': 84 | DataFormatter = Permutations(args) 85 | elif args.task == 'disjoint': 86 | DataFormatter = Disjoint(args) 87 | elif args.task == 'cifar100': 88 | DataFormatter = Cifar100_Disjoint(args) 89 | elif args.task == 'CUB200': 90 | DataFormatter = CUB200_Disjoint(args) 91 | else: 92 | print("Not Implemented") 93 | 94 | DataFormatter.formating_data() 95 | -------------------------------------------------------------------------------- /Data/permutations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | import torch 4 | 5 | class Permutations(object): 6 | def __init__(self, args): 7 | super(Permutations, self).__init__() 8 | 9 | self.n_tasks = args.n_tasks 10 | self.i = args.i 11 | self.train_file = args.train_file 12 | self.test_file = args.test_file 13 | 14 | self.o = os.path.join(self.i, self.train_file).replace('training.pt', 'permutations_' + str(self.n_tasks) + '.pt') 15 | #self.o = os.path.join(args.o, 'permutations_' + str(self.n_tasks) + '.pt') 16 | 17 | 18 | self.o_train = os.path.join(args.o, 'permutations_' + str(self.n_tasks) + '_train.pt') 19 | self.o_test = os.path.join(args.o, 'permutations_' + str(self.n_tasks) + '_test.pt') 20 | 21 | def formating_data(self): 22 | 23 | assert os.path.isfile(os.path.join(self.i, self.train_file)) 24 | assert os.path.isfile(os.path.join(self.i, self.test_file)) 25 | 26 | tasks_tr = [] 27 | tasks_te = [] 28 | 29 | x_tr, y_tr = torch.load(os.path.join(self.i, self.train_file)) 30 | x_te, y_te = torch.load(os.path.join(self.i, self.test_file)) 31 | x_tr = x_tr.float().view(x_tr.size(0), -1) / 255.0 32 | x_te = x_te.float().view(x_te.size(0), -1) / 255.0 33 | y_tr = y_tr.view(-1).long() 34 | y_te = y_te.view(-1).long() 35 | 36 | p = torch.FloatTensor(range(x_tr.size(1))).long() 37 | for t in range(self.n_tasks): 38 | 39 | tasks_tr.append(['random permutation', x_tr.index_select(1, p), y_tr]) 40 | tasks_te.append(['random permutation', x_te.index_select(1, p), y_te]) 41 | p = torch.randperm(x_tr.size(1)).long().view(-1) 42 | 43 | 44 | torch.save(tasks_tr, self.o_train) 45 | torch.save(tasks_te, self.o_test) 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | 50 | parser.add_argument('--i', default='raw/cifar100.pt', help='input directory') 51 | parser.add_argument('--o', default='cifar100.pt', help='output file') 52 | parser.add_argument('--n_tasks', default=10, type=int, help='number of tasks') 53 | parser.add_argument('--seed', default=0, type=int, help='random seed') 54 | parser.add_argument('--train_file', default='', help='input directory') 55 | parser.add_argument('--test_file', default='', help='input directory') 56 | args = parser.parse_args() 57 | 58 | torch.manual_seed(args.seed) 59 | 60 | DataFormatter = Permutations(args) 61 | DataFormatter.formating_data() 62 | -------------------------------------------------------------------------------- /Data/rotations.py: -------------------------------------------------------------------------------- 1 | 2 | from torchvision import transforms 3 | from PIL import Image 4 | import argparse 5 | import os.path 6 | import random 7 | import torch 8 | 9 | class Rotations(object): 10 | def __init__(self, args): 11 | super(Rotations, self).__init__() 12 | 13 | self.n_tasks = args.n_tasks 14 | self.i = args.i 15 | self.image_size = args.imageSize 16 | self.min_rot = args.min_rot 17 | self.max_rot = args.max_rot 18 | self.train_file = args.train_file 19 | self.test_file = args.test_file 20 | self.o = os.path.join(self.i, self.train_file).replace('training.pt', 'rotations_' + str(self.n_tasks) + '.pt') 21 | 22 | self.o_train = os.path.join(args.o, 'rotations_' + str(self.n_tasks) + '_train.pt') 23 | self.o_test = os.path.join(args.o, 'rotations_' + str(self.n_tasks) + '_test.pt') 24 | 25 | def rotate_dataset(self, d, rotation): 26 | result = torch.FloatTensor(d.size(0), 784) 27 | tensor = transforms.ToTensor() 28 | 29 | for i in range(d.size(0)): 30 | img = Image.fromarray(d[i].numpy(), mode='L') 31 | result[i] = tensor(img.rotate(rotation)).view(784) 32 | return result 33 | 34 | def formating_data(self): 35 | 36 | assert os.path.isfile(os.path.join(self.i, self.train_file)) 37 | assert os.path.isfile(os.path.join(self.i, self.test_file)) 38 | 39 | tasks_tr = [] 40 | tasks_te = [] 41 | 42 | x_tr, y_tr = torch.load(os.path.join(self.i, self.train_file)) 43 | x_te, y_te = torch.load(os.path.join(self.i, self.test_file)) 44 | 45 | for t in range(self.n_tasks): 46 | min_rot = 1.0 * t / self.n_tasks * (self.max_rot - self.min_rot) + \ 47 | self.min_rot 48 | max_rot = 1.0 * (t + 1) / self.n_tasks * \ 49 | (self.max_rot - self.min_rot) + self.min_rot 50 | rot = random.random() * (max_rot - min_rot) + min_rot 51 | 52 | tasks_tr.append([rot, self.rotate_dataset(x_tr, rot), y_tr]) 53 | tasks_te.append([rot, self.rotate_dataset(x_te, rot), y_te]) 54 | 55 | torch.save(tasks_tr, self.o_train) 56 | torch.save(tasks_te, self.o_test) 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument('--i', default='raw/', help='input directory') 62 | parser.add_argument('--o', default='mnist_rotations.pt', help='output file') 63 | parser.add_argument('--n_tasks', default=10, type=int, help='number of tasks') 64 | parser.add_argument('--min_rot', default=0., 65 | type=float, help='minimum rotation') 66 | parser.add_argument('--max_rot', default=90., 67 | type=float, help='maximum rotation') 68 | parser.add_argument('--seed', default=0, type=int, help='random seed') 69 | 70 | args = parser.parse_args() 71 | torch.manual_seed(args.seed) 72 | 73 | DataFormatter = Mnist_Rotation(args) 74 | DataFormatter.formating_data() 75 | -------------------------------------------------------------------------------- /Evaluation/Eval_Classifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from tqdm import tqdm 5 | 6 | import utils 7 | from Classifiers.Fashion_Classifier import Fashion_Classifier 8 | from Classifiers.Mnist_Classifier import Mnist_Classifier 9 | from Data.load_dataset import load_dataset_full, load_dataset_test, get_iter_dataset 10 | from log_utils import * 11 | from Data.data_loader import DataLoader 12 | from Evaluation.tools import calculate_frechet_distance 13 | 14 | mpl.use('Agg') 15 | 16 | 17 | class Reviewer_C(object): 18 | def __init__(self, args): 19 | # parameters 20 | self.args = args 21 | self.epoch_Review = args.epoch_Review 22 | self.sample_num = 64 23 | self.batch_size = args.batch_size 24 | self.save_dir = args.save_dir 25 | self.result_dir = args.result_dir 26 | self.sample_dir = args.sample_dir 27 | self.dataset = args.dataset 28 | self.log_dir = args.log_dir 29 | self.gpu_mode = args.gpu_mode 30 | self.model_name = args.gan_type 31 | self.data_dir = args.data_dir 32 | self.gen_dir = args.gen_dir 33 | self.verbose = args.verbose 34 | 35 | self.lr = args.lrC 36 | self.momentum = args.momentum 37 | self.log_interval = 100 38 | self.sample_num = 100 39 | self.size_epoch = args.size_epoch 40 | self.gan_type = args.gan_type 41 | self.conditional = args.conditional 42 | self.device = args.device 43 | self.TrainEval = args.TrainEval 44 | self.num_task = args.num_task 45 | self.task_type = args.task_type 46 | self.context = args.context 47 | 48 | self.seed = args.seed 49 | 50 | if self.conditional: 51 | self.model_name = 'C' + self.model_name 52 | 53 | # Load the generator parameters 54 | 55 | # The reviewer evaluate generate dataset (loader train) on true data (loader test) 56 | # not sur yet if valid should be real or not (it was before) 57 | dataset_train, dataset_valid, list_class_train, list_class_valid = load_dataset_full(self.data_dir, 58 | args.dataset) 59 | dataset_test, list_class_test = load_dataset_test(self.data_dir, args.dataset, args.batch_size) 60 | 61 | # create data loader for validation and testing 62 | self.valid_loader = get_iter_dataset(dataset_valid) 63 | self.test_loader = get_iter_dataset(dataset_test) 64 | 65 | if self.dataset == 'mnist': 66 | self.input_size = 1 67 | self.size = 28 68 | elif self.dataset == 'fashion': 69 | self.input_size = 1 70 | self.size = 28 71 | elif self.dataset == 'cifar10': 72 | self.input_size = 3 73 | self.size = 32 74 | 75 | if self.dataset == 'mnist': 76 | self.Classifier = Mnist_Classifier(args) 77 | elif self.dataset == 'fashion': 78 | self.Classifier = Fashion_Classifier(args) 79 | else: 80 | print('Not implemented') 81 | 82 | # this should be train on task 83 | def train_classifier(self, epoch, data_loader_train, ind_task): 84 | self.Classifier.net.train() 85 | 86 | train_loss_classif, train_accuracy = self.Classifier.train_on_task(data_loader_train, ind_task=ind_task, epoch=epoch, 87 | additional_loss=None) 88 | val_loss_classif, valid_accuracy, classe_prediction, classe_total, classe_wrong = self.Classifier.eval_on_task( 89 | self.valid_loader, epoch) 90 | 91 | if self.verbose: 92 | print( 93 | 'Epoch: {} Train set: Average loss: {:.4f}, Accuracy: ({:.2f}%)\n Valid set: Average loss: {:.4f}, Accuracy: ({:.2f}%)'.format( 94 | epoch, train_loss_classif, train_accuracy, val_loss_classif, valid_accuracy)) 95 | return train_loss_classif, train_accuracy, val_loss_classif, valid_accuracy, ( 96 | 100. * classe_prediction) / classe_total 97 | 98 | 99 | 100 | def review(self, data_loader_train, value): 101 | 102 | 103 | self.Classifier.reinit() 104 | 105 | best_accuracy = -1 106 | train_loss = [] 107 | train_acc = [] 108 | val_loss = [] 109 | val_acc = [] 110 | valid_acc = [] 111 | valid_acc_classes = [] 112 | 113 | print("Number of samples : " + str(value)) 114 | 115 | early_stop = 0. 116 | # Training classifier 117 | for epoch in range(self.epoch_Review): 118 | tr_loss, tr_acc, v_loss, v_acc, v_acc_classes = self.train_classifier(epoch, data_loader_train, value) 119 | train_loss.append(tr_loss) 120 | train_acc.append(tr_acc) 121 | val_loss.append(v_loss) 122 | val_acc.append(v_acc) 123 | # Save best model 124 | if v_acc > best_accuracy: 125 | if self.verbose: 126 | print("New Best Classifier") 127 | print(v_acc) 128 | best_accuracy = v_acc 129 | self.save(best=True) 130 | early_stop = 0. 131 | if early_stop == 60: 132 | break 133 | else: 134 | early_stop += 1 135 | valid_acc.append(np.array(v_acc)) 136 | valid_acc_classes.append(np.array(v_acc_classes)) 137 | 138 | # Then load best model 139 | self.load() 140 | 141 | loss, test_acc, classe_prediction, classe_total, classe_wrong = self.Classifier.eval_on_task( 142 | self.test_loader, 0) 143 | 144 | 145 | 146 | test_acc_classes = 100. * classe_prediction / classe_total 147 | 148 | if self.verbose: 149 | print('\nTest set: Average loss: {:.4f}, Accuracy : ({:.2f}%)'.format( 150 | loss, test_acc )) 151 | 152 | for i in range(10): 153 | print('Classe {} Accuracy: {}/{} ({:.3f}%, Wrong : {})'.format( 154 | i, classe_prediction[i], classe_total[i], 155 | 100. * classe_prediction[i] / classe_total[i], classe_wrong[i])) 156 | 157 | print('\n') 158 | 159 | 160 | # loss, test_acc, test_acc_classes = self.test() # self.test_classifier(epoch) 161 | np.savetxt(os.path.join(self.log_dir,'data_classif_' + self.dataset + '-num_samples_' + str(value) + '.txt'), 162 | np.transpose([train_loss, train_acc, val_loss, val_acc])) 163 | np.savetxt(os.path.join(self.log_dir,'best_score_classif_' + self.dataset + '-num_samples_' + str(value) + '.txt'), 164 | np.transpose([test_acc])) 165 | np.savetxt(os.path.join(self.log_dir,'data_classif_classes' + self.dataset + '-num_samples_' + str(value) + '.txt'), 166 | np.transpose([test_acc_classes])) 167 | 168 | 169 | return valid_acc, valid_acc_classes 170 | 171 | def review_all_tasks(self, args, list_values): 172 | for value in list_values: 173 | # create data set with value samples 174 | dataset_train, dataset_valid, list_class_train, list_class_valid = load_dataset_full(self.data_dir, 175 | args.dataset, value) 176 | 177 | data_loader_train = get_iter_dataset(dataset_train) 178 | #data_loader_train = DataLoader(dataset_train, args) 179 | 180 | self.review(data_loader_train, value) 181 | 182 | # save a classifier or the best classifier 183 | def save(self, best=False): 184 | 185 | if not os.path.exists(self.save_dir): 186 | os.makedirs(self.save_dir) 187 | 188 | if best: 189 | torch.save(self.Classifier.net.state_dict(), 190 | os.path.join(self.save_dir, self.model_name + '_Classifier_Best.pkl')) 191 | else: 192 | torch.save(self.Classifier.net.state_dict(), 193 | os.path.join(self.save_dir, self.model_name + '_Classifier.pkl')) 194 | 195 | # load the best classifier or the reference classifier trained on true data only 196 | def load(self, reference=False): 197 | if reference: 198 | save_dir = os.path.join(self.save_dir, "..", "..", "..", "Classifier", 'seed_' + str(self.seed)) 199 | self.Classifier.net.load_state_dict(torch.load(os.path.join(save_dir, 'Classifier_Classifier_Best.pkl'))) 200 | else: 201 | self.Classifier.net.load_state_dict( 202 | torch.load(os.path.join(self.save_dir, self.model_name + '_Classifier_Best.pkl'))) 203 | 204 | def load_best_baseline(self): 205 | 206 | # best seed searched in the list define in get_best_baseline function, liste_seed = [1, 2, 3, 4, 5, 6, 7, 8] 207 | best_seed = utils.get_best_baseline(self.log_dir, self.dataset) 208 | 209 | save_dir = os.path.join(self.save_dir, "..", "..", "..", "Classifier", 'seed_' + str(best_seed)) 210 | self.Classifier.net.load_state_dict(torch.load(os.path.join(save_dir, 'Classifier_Classifier_Best.pkl'))) 211 | -------------------------------------------------------------------------------- /Evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Evaluation/__init__.py -------------------------------------------------------------------------------- /Evaluation/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import linalg 4 | from scipy.stats import entropy 5 | #from sklearn.neighbors import KNeighborsClassifier 6 | from torch.autograd import Variable 7 | 8 | from log_utils import * 9 | 10 | mpl.use('Agg') 11 | 12 | import warnings 13 | 14 | 15 | 16 | 17 | 18 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 19 | # token from https://github.com/bioinf-jku/TTUR/blob/master/fid.py 20 | 21 | """Numpy implementation of the Frechet Distance. 22 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 23 | and X_2 ~ N(mu_2, C_2) is 24 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 25 | 26 | Stable version by Dougal J. Sutherland. 27 | Params: 28 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the 29 | inception net ( like returned by the function 'get_predictions') 30 | for generated samples. 31 | -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted 32 | on an representive data set. 33 | -- sigma1: The covariance matrix over activations of the pool_3 layer for 34 | generated samples. 35 | -- sigma2: The covariance matrix over activations of the pool_3 layer, 36 | precalcualted on an representive data set. 37 | Returns: 38 | -- : The Frechet Distance. 39 | """ 40 | 41 | mu1 = np.atleast_1d(mu1) 42 | mu2 = np.atleast_1d(mu2) 43 | 44 | sigma1 = np.atleast_2d(sigma1) 45 | sigma2 = np.atleast_2d(sigma2) 46 | 47 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" 48 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" 49 | 50 | diff = mu1 - mu2 51 | 52 | # product might be almost singular 53 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 54 | if not np.isfinite(covmean).all(): 55 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps 56 | warnings.warn(msg) 57 | offset = np.eye(sigma1.shape[0]) * eps 58 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 59 | 60 | # numerical error might give slight imaginary component 61 | if np.iscomplexobj(covmean): 62 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 63 | m = np.max(np.abs(covmean.imag)) 64 | #raise ValueError("Imaginary component {}".format(m)) 65 | print('FID is fucked up') 66 | covmean = covmean.real 67 | 68 | tr_covmean = np.trace(covmean) 69 | 70 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 71 | 72 | def Inception_score(self): 73 | 74 | eval_size = 500 75 | 76 | # 0. load reference classifier 77 | self.load_best_baseline() #we load the best classifier 78 | 79 | # 1. generate data 80 | 81 | self.Classifier.eval() 82 | 83 | output_table = torch.Tensor(eval_size * self.batch_size, 10) 84 | 85 | # compute IS on real data 86 | if self.tau == 0: 87 | if len(self.test_loader) < eval_size: 88 | output_table = torch.Tensor((len(self.test_loader) - 1) * self.batch_size, 10) 89 | print("Computing of IS on test data") 90 | for i, (data, target) in enumerate(self.test_loader): 91 | if i >= eval_size or i >= (len(self.test_loader) - 1): # (we throw away the last batch) 92 | break 93 | if self.gpu_mode: 94 | data, target = data.cuda(self.device), target.cuda(self.device) 95 | batch = Variable(data) 96 | label = Variable(target.squeeze()) 97 | classif = self.Classifier(batch) 98 | output_table[i * self.batch_size:(i + 1) * self.batch_size, :] = classif.data 99 | elif self.tau == -1: 100 | if len(self.train_loader) < eval_size: 101 | output_table = torch.Tensor((len(self.train_loader) - 1) * self.batch_size, 10) 102 | print("Computing of IS on train data") 103 | for i, (data, target) in enumerate(self.train_loader): 104 | if i >= eval_size or i >= (len(self.train_loader) - 1): # (we throw away the last batch) 105 | break 106 | if self.gpu_mode: 107 | data, target = data.cuda(self.device), target.cuda(self.device) 108 | batch = Variable(data) 109 | label = Variable(target.squeeze()) 110 | classif = self.Classifier(batch) 111 | output_table[i * self.batch_size:(i + 1) * self.batch_size, :] = classif.data 112 | else: 113 | print("Computing of IS on generated data") 114 | for i in range(eval_size): 115 | data, target = self.generator.sample(self.batch_size) 116 | # 2. use the reference classifier to compute the output vector 117 | if self.gpu_mode: 118 | data, target = data.cuda(self.device), target.cuda(self.device) 119 | batch = Variable(data) 120 | label = Variable(target.squeeze()) 121 | classif = self.Classifier(batch) 122 | 123 | output_table[i * self.batch_size:(i + 1) * self.batch_size, :] = classif.data 124 | 125 | # Now compute the mean kl-div 126 | py = output_table.mean(0) 127 | 128 | assert py.shape[0] == 10 129 | 130 | scores = [] 131 | for i in range(output_table.shape[0]): 132 | pyx = output_table[i, :] 133 | assert pyx.shape[0] == py.shape[0] 134 | scores.append(entropy(pyx.tolist(), py.tolist())) # compute the KL-Divergence KL(P(Y|X)|P(Y)) 135 | Inception_score = np.exp(np.asarray(scores).mean()) 136 | 137 | if self.tau == 0: 138 | print("save reference IS") 139 | log_dir = os.path.join(self.log_dir, "..", "..", "..", "Classifier", 'seed_' + str(self.seed)) 140 | np.savetxt(os.path.join(os.path.join(log_dir, 'Inception_score_ref_' + self.dataset + '.txt')), 141 | np.transpose([Inception_score])) 142 | elif self.tau == -1: 143 | print("save IS evaluate on train") 144 | log_dir = os.path.join(self.log_dir, "..", "..", "..", "Classifier", 'seed_' + str(self.seed)) 145 | np.savetxt(os.path.join(os.path.join(log_dir, 'Inception_score_train_' + self.dataset + '.txt')), 146 | np.transpose([Inception_score])) 147 | else: 148 | np.savetxt(os.path.join(self.log_dir, 'Inception_score_' + self.dataset + '.txt'), 149 | np.transpose([Inception_score])) 150 | 151 | print("Inception Score") 152 | print(Inception_score) 153 | 154 | 155 | def knn(self): 156 | print("Training KNN Classifier") 157 | # Declare Classifier model 158 | data_samples = [] 159 | label_samples = [] 160 | 161 | # Training knn 162 | neigh = KNeighborsClassifier(n_neighbors=1) 163 | # We get the test data 164 | for i, (d, t) in enumerate(self.test_loader): 165 | if i == 0: 166 | data_test = d 167 | label_test = t 168 | else: 169 | data_test = torch.cat((data_test, d)) 170 | label_test = torch.cat((label_test, t)) 171 | data_test = data_test.numpy().reshape(-1, 784) 172 | label_test = label_test.numpy() 173 | # We get the training data 174 | for i, (d, t) in enumerate(self.train_loader): 175 | if i == 0: 176 | data_train = d 177 | label_train = t 178 | else: 179 | data_train = torch.cat((data_train, d)) 180 | label_train = torch.cat((label_train, t)) 181 | data = data_train.numpy().reshape(-1, 784) 182 | labels = label_train.numpy() 183 | 184 | if self.tau > 0: 185 | # we reduce the dataset 186 | data = data[0:int(len(data_train) * (1 - self.tau))] 187 | labels = labels[0:int(len(data_train) * (1 - self.tau))] 188 | # We get samples from the models 189 | for i in range(int((label_train.shape[0] * self.tau) / self.batch_size)): 190 | data_gen, label_gen = self.generator.sample(self.batch_size) 191 | data_samples.append(data_gen.cpu().numpy()) 192 | label_samples.append(label_gen.cpu().numpy()) 193 | 194 | # We concatenate training and gen samples 195 | data_samples = np.concatenate(data_samples).reshape(-1, 784) 196 | label_samples = np.concatenate(label_samples).squeeze() 197 | data = np.concatenate([data, data_samples]) 198 | labels = np.concatenate([labels, label_samples]) 199 | 200 | # We train knn 201 | neigh.fit(data, labels) 202 | accuracy = neigh.score(data_test,label_test) 203 | print("accuracy=%.2f%%" % (accuracy * 100)) 204 | 205 | 206 | np.savetxt(os.path.join(self.log_dir, 'best_score_knn_' + self.dataset + '-tau' + str(self.tau) + '.txt'), 207 | np.transpose([accuracy])) -------------------------------------------------------------------------------- /Generative_Models/BEGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os 2 | import numpy as np 3 | from torch.autograd import Variable 4 | from Generative_Models.Generative_Model import GenerativeModel 5 | from Data.load_dataset import get_iter_dataset 6 | 7 | 8 | 9 | 10 | class BEGAN(GenerativeModel): 11 | def __init__(self, args): 12 | super(BEGAN, self).__init__(args) 13 | self.gamma = 0.75 14 | self.lambda_ = 0.001 15 | self.k = 0. 16 | 17 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 18 | self.size_epoch = 1000 19 | 20 | if self.gpu_mode: 21 | self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda(self.device)), Variable( 22 | torch.zeros(self.batch_size, 1).cuda(self.device)) 23 | else: 24 | self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable( 25 | torch.zeros(self.batch_size, 1)) 26 | 27 | self.G.train() 28 | self.D.train() 29 | 30 | epoch_start_time = time.time() 31 | sum_loss_train = 0. 32 | 33 | for iter, (x_, t_) in enumerate(train_loader): 34 | 35 | if x_.size(0) != self.batch_size: 36 | break 37 | 38 | x_ = x_.view((-1, 1, 28, 28)) 39 | z_ = torch.rand((self.batch_size, self.z_dim)) 40 | 41 | if self.gpu_mode: 42 | x_, z_ = Variable(x_.cuda()), Variable(z_.cuda()) 43 | else: 44 | x_, z_ = Variable(x_), Variable(z_) 45 | 46 | # update D network 47 | self.D_optimizer.zero_grad() 48 | 49 | D_real = self.D(x_) 50 | D_real_err = torch.mean(torch.abs(D_real - x_)) 51 | 52 | G_ = self.G(z_) 53 | D_fake = self.D(G_) 54 | D_fake_err = torch.mean(torch.abs(D_fake - G_)) 55 | 56 | D_loss = D_real_err - self.k * D_fake_err 57 | self.train_hist['D_loss'].append(D_loss.data[0]) 58 | 59 | D_loss.backward() 60 | self.D_optimizer.step() 61 | 62 | # update G network 63 | self.G_optimizer.zero_grad() 64 | 65 | G_ = self.G(z_) 66 | D_fake = self.D(G_) 67 | D_fake_err = torch.mean(torch.abs(D_fake - G_)) 68 | 69 | G_loss = D_fake_err 70 | self.train_hist['G_loss'].append(G_loss.data[0]) 71 | 72 | G_loss.backward() 73 | self.G_optimizer.step() 74 | 75 | # convergence metric 76 | temp_M = D_real_err + torch.abs(self.gamma * D_real_err - D_fake_err) 77 | 78 | # operation for updating k 79 | temp_k = self.k + self.lambda_ * (self.gamma * D_real_err - D_fake_err) 80 | temp_k = temp_k.data[0] 81 | 82 | # self.k = temp_k.data[0] 83 | self.k = min(max(temp_k, 0), 1) 84 | self.M = temp_M.data[0] 85 | 86 | if self.verbose: 87 | if ((iter + 1) % 100) == 0: 88 | print("Ind_task : [%1d] Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" % 89 | (ind_task, (epoch + 1), (iter + 1), self.size_epoch, 90 | D_loss.data[0], G_loss.data[0], self.M, self.k)) 91 | 92 | # the following line is probably wrong 93 | self.train_hist['total_time'].append(time.time() - epoch_start_time) 94 | 95 | 96 | if self.verbose: 97 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 98 | self.epoch, 99 | self.train_hist['total_time'][0])) 100 | print("Training finish!... save training results") 101 | 102 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 103 | self.save() 104 | 105 | sum_loss_train = sum_loss_train / np.float(len(train_loader)) 106 | 107 | return sum_loss_train 108 | 109 | ''' 110 | def train(self): 111 | 112 | self.G.apply(self.G.weights_init) 113 | self.D.train() 114 | 115 | for classe in range(10): 116 | self.train_hist = {} 117 | self.train_hist['D_loss'] = [] 118 | self.train_hist['G_loss'] = [] 119 | self.train_hist['per_epoch_time'] = [] 120 | self.train_hist['total_time'] = [] 121 | # self.G.apply(self.G.weights_init) does not work for instance 122 | 123 | if self.gpu_mode: 124 | self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda()) 125 | else: 126 | self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1)) 127 | 128 | self.D.train() 129 | self.data_loader_train = get_iter_dataset(self.dataset_train, self.list_class_train, self.batch_size, 130 | classe) 131 | self.data_loader_valid = get_iter_dataset(self.dataset_valid, self.list_class_valid, self.batch_size, 132 | classe) 133 | print('training class : ' + str(classe)) 134 | start_time = time.time() 135 | for epoch in range(self.epoch): 136 | self.G.train() 137 | epoch_start_time = time.time() 138 | n_batch = 0. 139 | 140 | for iter, (x_, t_) in enumerate(self.data_loader_train): 141 | n_batch += 1 142 | z_ = torch.rand((self.batch_size, self.z_dim)) 143 | 144 | if self.gpu_mode: 145 | x_, z_ = Variable(x_.cuda()), Variable(z_.cuda()) 146 | else: 147 | x_, z_ = Variable(x_), Variable(z_) 148 | 149 | # update D network 150 | self.D_optimizer.zero_grad() 151 | 152 | D_real = self.D(x_) 153 | D_real_err = torch.mean(torch.abs(D_real - x_)) 154 | 155 | G_ = self.G(z_) 156 | D_fake = self.D(G_) 157 | D_fake_err = torch.mean(torch.abs(D_fake - G_)) 158 | 159 | D_loss = D_real_err - self.k * D_fake_err 160 | self.train_hist['D_loss'].append(D_loss.data[0]) 161 | 162 | D_loss.backward() 163 | self.D_optimizer.step() 164 | 165 | # update G network 166 | self.G_optimizer.zero_grad() 167 | 168 | G_ = self.G(z_) 169 | D_fake = self.D(G_) 170 | D_fake_err = torch.mean(torch.abs(D_fake - G_)) 171 | 172 | G_loss = D_fake_err 173 | self.train_hist['G_loss'].append(G_loss.data[0]) 174 | 175 | G_loss.backward() 176 | self.G_optimizer.step() 177 | 178 | # convergence metric 179 | temp_M = D_real_err + torch.abs(self.gamma * D_real_err - D_fake_err) 180 | 181 | # operation for updating k 182 | temp_k = self.k + self.lambda_ * (self.gamma * D_real_err - D_fake_err) 183 | temp_k = temp_k.data[0] 184 | 185 | # self.k = temp_k.data[0] 186 | self.k = min(max(temp_k, 0), 1) 187 | self.M = temp_M.data[0] 188 | 189 | if ((iter + 1) % 100) == 0: 190 | print("classe : [%1d] Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" % 191 | (classe, (epoch + 1), (iter + 1), self.size_epoch, 192 | D_loss.data[0], G_loss.data[0], self.M, self.k)) 193 | 194 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 195 | self.visualize_results((epoch+1), classe) 196 | 197 | self.save_G(classe) 198 | 199 | result_dir = self.result_dir + '/' + 'classe-' + str(classe) 200 | utils.generate_animation(result_dir + '/' + self.model_name, epoch + 1) 201 | utils.loss_plot(self.train_hist, result_dir, self.model_name) 202 | 203 | np.savetxt( 204 | os.path.join(result_dir, 'began_training_' + self.dataset + '.txt'), 205 | np.transpose([self.train_hist['G_loss']])) 206 | 207 | self.train_hist['total_time'].append(time.time() - start_time) 208 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 209 | self.epoch, self.train_hist['total_time'][0])) 210 | print("Training finish!... save training results") 211 | 212 | ''' -------------------------------------------------------------------------------- /Generative_Models/CGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time 2 | import numpy as np 3 | from torch.autograd import Variable 4 | from Generative_Models.Conditional_Model import ConditionalModel 5 | from Data.load_dataset import get_iter_dataset 6 | from torch.utils.data import DataLoader 7 | 8 | from utils import variable 9 | 10 | import math 11 | 12 | 13 | class CGAN(ConditionalModel): 14 | 15 | 16 | def run_batch(self, x_, t_, additional_loss=None): 17 | 18 | for p in self.D.parameters(): # reset requires_grad 19 | p.requires_grad = True # they are set to False below in netG update 20 | 21 | self.G.train() 22 | self.D.train() 23 | 24 | x_ = x_.view((-1, 1, 28, 28)) 25 | y_onehot = variable(self.get_one_hot(t_)) 26 | z_ = variable(torch.rand((x_.size(0), self.z_dim))) 27 | 28 | # update D network 29 | self.D_optimizer.zero_grad() 30 | 31 | D_real = self.D(x_, y_onehot) 32 | 33 | D_real_loss = self.BCELoss(D_real[0], self.y_real_[:x_.size(0)]) 34 | 35 | G_ = self.G(z_, y_onehot) 36 | 37 | D_fake = self.D(G_, y_onehot) 38 | D_fake_loss = self.BCELoss(D_fake[0], self.y_fake_[:x_.size(0)]) 39 | 40 | D_loss = D_real_loss + D_fake_loss 41 | 42 | D_loss.backward() 43 | self.D_optimizer.step() 44 | 45 | for p in self.D.parameters(): # reset requires_grad 46 | p.requires_grad = False # they are set to False below in netG update 47 | 48 | # update G network 49 | self.G_optimizer.zero_grad() 50 | 51 | G_ = self.G(z_, y_onehot) 52 | D_fake = self.D(G_, y_onehot) 53 | G_loss = self.BCELoss(D_fake[0], self.y_real_[:x_.size(0)]) 54 | 55 | 56 | if additional_loss is not None: 57 | regularization = additional_loss(self) 58 | G_loss += regularization 59 | 60 | G_loss.backward() 61 | self.G_optimizer.step() 62 | return G_loss.item() 63 | 64 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 65 | 66 | self.G.train() 67 | self.D.train() 68 | epoch_start_time = time.time() 69 | 70 | sum_loss_train = 0. 71 | 72 | 73 | for iter, (x_, t_) in enumerate(train_loader): 74 | 75 | 76 | if self.num_task==10 and ind_task != self.num_task : 77 | 78 | # An image can be wrongly labelled by a label from futur task 79 | # it is not a ethical problem and it help learning 80 | 81 | # if ind_task != self.num_task, there no more future task, nothing to do 82 | 83 | # the following line produce a vector of batch_size with label from ind_task to self.num_task-1 84 | rand_t_ = ((torch.randperm(1000) % (self.num_task-ind_task)).long() + ind_task)[:x_.size(0)] 85 | mask = (t_ == ind_task).long() 86 | 87 | # if we are in a past task we keep the true label 88 | # else we put a random label from the futur label 89 | t_ = torch.mul(t_, 1 - mask) + torch.mul(rand_t_, mask) 90 | 91 | x_ = variable(x_.view((-1, 1, 28, 28))) 92 | sum_loss_train+=self.run_batch(x_, t_, additional_loss=None) 93 | 94 | 95 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 96 | self.save() 97 | 98 | sum_loss_train = sum_loss_train / np.float(len(train_loader)) 99 | 100 | return sum_loss_train -------------------------------------------------------------------------------- /Generative_Models/CVAE.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from Generative_Models.VAE import VAE 8 | from Generative_Models.Conditional_Model import ConditionalModel 9 | 10 | from utils import variable 11 | 12 | 13 | class CVAE(ConditionalModel, VAE): 14 | def __init__(self, args): 15 | super(CVAE, self).__init__(args) 16 | 17 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 18 | 19 | self.E.train() 20 | self.G.train() 21 | sum_loss_train = 0. 22 | n_batch = 0. 23 | 24 | for iter, (x_, t_) in enumerate(train_loader): 25 | 26 | x_ = variable(x_) 27 | y_onehot = variable(self.get_one_hot(t_)) 28 | 29 | self.E_optimizer.zero_grad() 30 | self.G_optimizer.zero_grad() 31 | # VAE 32 | z_, mu, logvar = self.E(x_, y_onehot) 33 | recon_batch = self.G(z_, y_onehot) 34 | 35 | G_loss = self.loss_function(recon_batch, x_.view(-1,1,28,28), mu, logvar) 36 | sum_loss_train += G_loss.item() 37 | 38 | #regularization = additional_loss([self.G, self.E]) 39 | regularization = additional_loss(self) 40 | #regularization = additional_loss(self.E) 41 | 42 | if regularization is not None: 43 | G_loss += regularization 44 | 45 | G_loss.backward() 46 | self.E_optimizer.step() 47 | self.G_optimizer.step() 48 | 49 | n_batch += 1 50 | 51 | if self.verbose: 52 | if ((iter + 1) % 100) == 0: 53 | print("Task : [%1d] Epoch: [%2d] [%4d/%4d] G_loss: %.8f, E_loss: %.8f" % 54 | (ind_task, (epoch + 1), (iter + 1), self.size_epoch, G_loss.item(), G_loss.item())) 55 | 56 | sum_loss_train = sum_loss_train / np.float(n_batch) 57 | return sum_loss_train 58 | -------------------------------------------------------------------------------- /Generative_Models/Conditional_Model.py: -------------------------------------------------------------------------------- 1 | from Generative_Models.Generative_Model import GenerativeModel 2 | import torch 3 | from utils import * 4 | 5 | 6 | class ConditionalModel(GenerativeModel): 7 | 8 | 9 | # if no task2generate are given we generate all labellize for all task 10 | # if task2generate and annotate == false we generate only for the actual task 11 | # if task2generate and annotate == true we generate only for all past tasks 12 | def sample(self, batch_size, task2generate=None, multi_annotation=False): 13 | ''' 14 | :param batch_size: 15 | :param task2generate: give the index of class to generate (the name is a bit misleading) 16 | :param multi_annotation: indicate if we want just one classes or all classes <= task2generate 17 | :param expert: classifier that can give a label to samples 18 | :return: batch of sample from different classes and return a batch of images and label 19 | ''' 20 | 21 | self.G.eval() 22 | 23 | if task2generate is not None: 24 | classes2generate=task2generate + 1 25 | else: 26 | classes2generate=self.num_classes 27 | 28 | z_ = self.random_tensor(batch_size, self.z_dim) 29 | if multi_annotation: 30 | # keep this please 31 | # y = torch.LongTensor(batch_size, 1).random_() % self.num_classes 32 | y = (torch.randperm(batch_size * 10) % classes2generate)[:batch_size] 33 | y_onehot = self.get_one_hot(y) 34 | else: 35 | y = (torch.ones(batch_size) * (classes2generate-1)).long() 36 | y_onehot = self.get_one_hot(y).cuda() 37 | 38 | output = self.G(variable(z_), y_onehot).data 39 | 40 | return output, y 41 | 42 | # For conditional Replay we generate tasks one by one 43 | def generate_batch4Task(self, nb_sample_train, task2generate, multi_annotation): 44 | return self.sample(batch_size=nb_sample_train, task2generate=task2generate, multi_annotation=False) 45 | 46 | 47 | 48 | def get_one_hot(self, y): 49 | y_onehot = torch.FloatTensor(y.shape[0], self.num_classes) 50 | y_onehot.zero_() 51 | y_onehot.scatter_(1, y[:, np.newaxis], 1.0) 52 | 53 | return y_onehot 54 | 55 | 56 | # This function generate a dataset for one class or for all class until ind_task included 57 | def generate_dataset(self, ind_task, nb_sample_per_task, one_task=True, Train=True, classe2generate=None): 58 | 59 | # to generate 10 classes classe2generate is 9 as classes 0 to 9 60 | if classe2generate is not None: 61 | assert classe2generate <= self.num_classes 62 | if self.task_type != "disjoint": 63 | assert classe2generate == self.num_classes 64 | else: 65 | classe2generate = ind_task+1 66 | 67 | train_loader_gen=None 68 | 69 | if Train: 70 | path = os.path.join(self.gen_dir, 'train_Task_' + str(ind_task) + '.pt') 71 | path_samples = os.path.join(self.sample_dir, 'samples_train_' + str(ind_task) + '.png') 72 | else: 73 | path = os.path.join(self.gen_dir, 'test_Task_' + str(ind_task) + '.pt') 74 | path_samples = os.path.join(self.sample_dir, 'samples_test_' + str(ind_task) + '.png') 75 | 76 | # if we have only on task to generate 77 | if one_task or classe2generate == 0: # generate only for the task ind_task 78 | 79 | train_loader_gen = self.generate_task(nb_sample_per_task, multi_annotation=False, classe2generate=classe2generate) 80 | 81 | else: # else case we generate for all previous task 82 | 83 | for i in range(classe2generate): # we take from all task, actual one included 84 | 85 | train_loader_ind = self.generate_task(nb_sample_per_task, multi_annotation=True, classe2generate=i) 86 | 87 | if i == 0: 88 | train_loader_gen = train_loader_ind 89 | else: 90 | train_loader_gen.concatenate(train_loader_ind) 91 | 92 | # we save the concatenation of all generated with the actual task for train and test 93 | train_loader_gen.save(path) 94 | train_loader_gen.visualize_sample(path_samples, self.sample_num, [self.size, self.size, self.input_size]) 95 | 96 | # return the the train loader with all data 97 | return train_loader_gen # test_loader_gen # for instance we don't use the test set 98 | 99 | # this generation only works for Baseline, disjoint 100 | # we generate the dataset based on one generator by task to get normally the best generated dataset 101 | # can be used to generate train or test data 102 | def generate_best_dataset(self, ind_task, nb_sample_per_task, one_task=True, Train=True, classe2generate=None): 103 | 104 | 105 | # to generate 10 classes classe2generate is 9 as classes 0 to 9 106 | if classe2generate is not None: 107 | assert classe2generate <= self.num_classes 108 | if self.task_type != "disjoint": 109 | assert classe2generate == self.num_classes 110 | else: 111 | classe2generate = ind_task+1 112 | 113 | if Train: 114 | path = os.path.join(self.gen_dir, 'Best_train_Task_' + str(ind_task) + '.pt') 115 | else: 116 | path = os.path.join(self.gen_dir, 'Best_test_Task_' + str(ind_task) + '.pt') 117 | 118 | # if we have only on task to generate 119 | if one_task or classe2generate == 0: # generate only for the task ind_task 120 | # we do not need automatic annotation since we have one generator by class 121 | previous_data_train = self.generate_task(nb_sample_per_task, multi_annotation=False, classe2generate=classe2generate) 122 | #previous_data_train = DataLoader(tasks_tr, self.args) 123 | 124 | else: # else we load the previous dataset and add the new data 125 | 126 | previous_path_train = os.path.join(self.gen_dir, 'Best_train_Task_' + str(ind_task - 1) + '.pt') 127 | 128 | previous_data_train = DataLoader(torch.load(previous_path_train), self.args) 129 | 130 | # we do not need automatic annotation since we have one generator by class 131 | train_loader_ind = self.generate_task(nb_sample_per_task, multi_annotation=False, classe2generate=i) 132 | 133 | previous_data_train.concatenate(train_loader_ind) 134 | 135 | # we save the concatenation of all generated with the actual task for train and test 136 | previous_data_train.save(path) 137 | 138 | # return nothing -------------------------------------------------------------------------------- /Generative_Models/GAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | from torch.autograd import Variable 5 | from utils import variable 6 | 7 | from Generative_Models.Generative_Model import GenerativeModel 8 | 9 | 10 | class GAN(GenerativeModel): 11 | 12 | 13 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 14 | self.size_epoch = 1000 15 | 16 | self.G.train() 17 | self.D.train() 18 | 19 | epoch_start_time = time.time() 20 | sum_loss_train = 0. 21 | 22 | 23 | for iter, (x_, t_ ) in enumerate(train_loader): 24 | 25 | 26 | if x_.size(0) != self.batch_size: 27 | break 28 | 29 | x_ = x_.view((-1, self.input_size, self.size, self.size)) 30 | 31 | z_ = torch.rand((x_.size(0), self.z_dim)) 32 | 33 | x_, z_ = variable(x_), variable(z_) 34 | 35 | # update D network 36 | self.D_optimizer.zero_grad() 37 | 38 | D_real = self.D(x_) 39 | D_real_loss = self.BCELoss(D_real, self.y_real_[:x_.size(0)]) 40 | 41 | G_ = self.G(z_) 42 | D_fake = self.D(G_) 43 | D_fake_loss = self.BCELoss(D_fake, self.y_fake_[:x_.size(0)]) 44 | 45 | D_loss = D_real_loss + D_fake_loss 46 | self.train_hist['D_loss'].append(D_loss.item()) 47 | 48 | D_loss.backward() 49 | self.D_optimizer.step() 50 | 51 | # update G network 52 | self.G_optimizer.zero_grad() 53 | 54 | G_ = self.G(z_) 55 | D_fake = self.D(G_) 56 | G_loss = self.BCELoss(D_fake, self.y_real_[:x_.size(0)]) 57 | self.train_hist['G_loss'].append(G_loss.item()) 58 | sum_loss_train += G_loss.item() 59 | 60 | regularization = additional_loss(self) 61 | 62 | if regularization is not None: 63 | G_loss += regularization 64 | 65 | G_loss.backward() 66 | self.G_optimizer.step() 67 | 68 | if self.verbose: 69 | if ((iter + 1) % 100) == 0: 70 | print("classe : [%1d] Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 71 | (ind_task, (epoch + 1), (iter + 1), len(train_loader), D_loss.data[0], G_loss.data[0])) 72 | 73 | 74 | #the following line is probably wrong 75 | self.train_hist['total_time'].append(time.time() - epoch_start_time) 76 | if self.verbose: 77 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 78 | self.epoch, self.train_hist['total_time'][0])) 79 | print("Training finish!... save training results") 80 | 81 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 82 | self.save() 83 | 84 | sum_loss_train = sum_loss_train / np.float(len(train_loader)) 85 | 86 | return sum_loss_train 87 | 88 | -------------------------------------------------------------------------------- /Generative_Models/Generative_Model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pickle 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torchvision.utils import save_image 8 | import numpy as np 9 | 10 | from Evaluation.Reviewer import * 11 | from Generative_Models.discriminator import Discriminator, Discriminator_Cifar 12 | from Generative_Models.generator import Generator, Generator_Cifar 13 | from log_utils import save_images 14 | from utils import variable 15 | from copy import deepcopy 16 | 17 | from Classifiers.Cifar_Classifier import Cifar_Classifier 18 | 19 | class GenerativeModel(object): 20 | def __init__(self, args): 21 | 22 | self.args = args 23 | 24 | # parameters 25 | self.epoch = args.epoch_G 26 | self.sample_num = 100 27 | self.batch_size = args.batch_size 28 | self.dataset = args.dataset 29 | self.gpu_mode = args.gpu_mode 30 | self.model_name = args.gan_type 31 | self.conditional = args.conditional 32 | self.seed = args.seed 33 | self.generators = [] 34 | self.c_criterion = nn.NLLLoss() 35 | self.size_epoch = args.size_epoch 36 | self.BCELoss = nn.BCELoss() 37 | self.device = args.device 38 | self.verbose = args.verbose 39 | 40 | self.save_dir = args.save_dir 41 | self.result_dir = args.result_dir 42 | self.data_dir = args.data_dir 43 | self.log_dir = args.log_dir 44 | self.gen_dir = args.gen_dir 45 | self.sample_dir = args.sample_dir 46 | 47 | self.task_type = args.task_type 48 | self.num_task = args.num_task 49 | self.num_classes = args.num_classes 50 | 51 | if self.dataset == 'mnist' or self.dataset == 'fashion': 52 | if self.model_name == 'VAE' or self.model_name == 'CVAE': 53 | self.z_dim = 20 54 | else: 55 | self.z_dim = 62 56 | self.input_size = 1 57 | self.size = 28 58 | elif self.dataset == 'cifar10': 59 | self.z_dim = 100 60 | self.input_size = 3 61 | self.size = 32 62 | 63 | if self.verbose: 64 | print("create G and D") 65 | 66 | if self.dataset=='cifar10': 67 | self.G = Generator_Cifar(self.z_dim, self.conditional) 68 | self.D = Discriminator_Cifar(self.conditional) 69 | else: 70 | self.G = Generator(self.z_dim, self.dataset, self.conditional, self.model_name) 71 | self.D = Discriminator(self.dataset, self.conditional, self.model_name) 72 | 73 | if self.verbose: 74 | print("create G and D 's optimizers") 75 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 76 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 77 | 78 | if self.gpu_mode: 79 | self.G=self.G.cuda(self.device) 80 | self.D=self.D.cuda(self.device) 81 | 82 | if self.verbose: 83 | print('---------- Networks architecture -------------') 84 | utils.print_network(self.G) 85 | utils.print_network(self.D) 86 | print('-----------------------------------------------') 87 | 88 | # fixed noise 89 | #self.sample_z_ = variable(torch.rand((self.sample_num, self.z_dim, 1, 1)), volatile=True) 90 | self.sample_z_ = variable(self.random_tensor(self.sample_num, self.z_dim)) 91 | 92 | if self.dataset == 'mnist': 93 | self.Classifier = Mnist_Classifier(self.args) 94 | elif self.dataset == 'fashion': 95 | self.Classifier = Fashion_Classifier(self.args) 96 | elif self.dataset == 'cifar10': 97 | self.Classifier = Cifar_Classifier(self.args) 98 | 99 | if self.gpu_mode: 100 | self.Classifier.net = self.Classifier.net.cuda(self.device) 101 | 102 | self.expert = copy.deepcopy(self.Classifier) 103 | self.expert.load_expert() 104 | 105 | # Logs 106 | self.train_hist = {} 107 | self.train_hist['D_loss'] = [] 108 | self.train_hist['G_loss'] = [] 109 | self.train_hist['per_epoch_time'] = [] 110 | self.train_hist['total_time'] = [] 111 | 112 | 113 | # usefull for all GAN 114 | self.y_real_ = variable(torch.ones(self.batch_size, 1)) 115 | self.y_fake_ = variable(torch.zeros(self.batch_size, 1)) 116 | 117 | 118 | def test(self, predict, labels): 119 | correct = 0 120 | pred = predict.data.max(1)[1] 121 | correct = pred.eq(labels.data).cpu().sum() 122 | return correct, len(labels.data) 123 | 124 | def random_tensor(self, batch_size, z_dim): 125 | # Uniform distribution 126 | return torch.rand((batch_size, z_dim, 1, 1)) 127 | 128 | # produce sample from one generator for visual inspection of a generator during training 129 | def visualize_results(self, epoch, classe=None, fix=True): 130 | 131 | sample_size=100 132 | 133 | # index allows, if there 5 task, to plot 2 classes for first task 134 | index = int(self.num_classes / self.num_task) * (classe + 1) 135 | 136 | self.G.eval() 137 | dir_path = self.result_dir 138 | if classe is not None: 139 | dir_path = self.result_dir + '/classe-' + str(classe) 140 | 141 | if not os.path.exists(dir_path): 142 | os.makedirs(dir_path) 143 | 144 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num))) 145 | if self.conditional: 146 | 147 | 148 | y = torch.LongTensor(range(self.sample_num)) % self.num_classes 149 | y=y.view(self.sample_num, 1) 150 | 151 | y_onehot = torch.FloatTensor(self.sample_num, self.num_classes) 152 | y_onehot.zero_() 153 | y_onehot.scatter_(1, y, 1.0) 154 | y_onehot = variable(y_onehot) 155 | else: 156 | y_onehot = None 157 | 158 | 159 | if fix: 160 | """ fixed noise """ 161 | if self.conditional: 162 | samples = self.G(self.sample_z_, y_onehot) 163 | else: 164 | samples = self.G(self.sample_z_) 165 | else: 166 | """ random noise """ 167 | sample_z_ = variable(self.random_tensor(self.sample_num, self.z_dim), volatile=True) 168 | 169 | if self.conditional: 170 | samples = self.G(sample_z_, y_onehot) 171 | else: 172 | samples = self.G(self.sample_z_) 173 | 174 | if self.input_size == 1: 175 | if self.gpu_mode: 176 | samples = samples.cpu().data.numpy() 177 | else: 178 | samples = samples.data.numpy() 179 | samples = samples.transpose(0, 2, 3, 1) 180 | save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 181 | dir_path + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 182 | else: 183 | save_image(samples[:self.sample_num].data, dir_path + '/' + self.model_name + '_epoch%03d' % epoch + '.png', 184 | padding=0) 185 | 186 | # produce sample from all classes and return a batch of images and label 187 | # if no task2generate are given we generate all labellize for all task 188 | # if task2generate and annotate == false we generate only for the actual task 189 | # if task2generate and annotate == true we generate only for all past tasks 190 | def sample(self, batch_size=100, task2generate=None, multi_annotation=False): 191 | 192 | self.G.eval() 193 | y = None 194 | 195 | z_ = self.random_tensor(batch_size, self.z_dim) 196 | output = self.G(variable(z_)) 197 | 198 | if not (task2generate is None): 199 | self.expert.net.eval() 200 | if multi_annotation: 201 | y = self.expert.labelize(output, task2generate) 202 | 203 | else:# if we generate only from actual task 204 | y = torch.ones(batch_size, 1).long() * task2generate 205 | y = y.long() 206 | else: # if no task2generate specified 207 | # if we generate from all task 208 | y = self.expert.labelize(output, self.num_classes) 209 | 210 | return output.data, y 211 | 212 | # load a conditonal generator, encoders and discriminators 213 | def load_G(self, ind_task): 214 | self.G.load_state_dict( 215 | torch.load(os.path.join(self.save_dir, self.model_name + '-' + str(ind_task) + '_G.pkl'))) 216 | 217 | # save a generator in a given class 218 | def save_G(self, task): 219 | if not os.path.exists(self.save_dir): 220 | os.makedirs(self.save_dir) 221 | torch.save(self.G.state_dict(), os.path.join(self.save_dir, self.model_name + '-' + str(task) + '_G.pkl')) 222 | 223 | # save a generator, encoder and discriminator in a given class 224 | def save(self): 225 | if not os.path.exists(self.save_dir): 226 | os.makedirs(self.save_dir) 227 | 228 | torch.save(self.G.state_dict(), os.path.join(self.save_dir, self.model_name + '_G.pkl')) 229 | torch.save(self.D.state_dict(), os.path.join(self.save_dir, self.model_name + '_D.pkl')) 230 | 231 | with open(os.path.join(self.save_dir, self.model_name + '_history.pkl'), 'wb') as f: 232 | pickle.dump(self.train_hist, f) 233 | 234 | def train(self): 235 | self.G.train() 236 | self.D.train() 237 | 238 | def eval(self): 239 | self.G.eval() 240 | self.D.eval() 241 | 242 | def generate_batch4Task(self, nb_sample_train, task2generate, multi_annotation): 243 | return self.sample(batch_size=nb_sample_train, task2generate=task2generate, multi_annotation=multi_annotation) 244 | 245 | def create_data_loader(self, nb_sample_train, task2generate, multi_annotation): 246 | 247 | c1 = 0 248 | c2 = 1 249 | 250 | tasks_tr = [] 251 | x_tr, y_tr = self.generate_batch4Task(nb_sample_train, task2generate=task2generate, 252 | multi_annotation=multi_annotation) 253 | if self.gpu_mode: 254 | x_tr, y_tr = x_tr.cpu(), y_tr.cpu() 255 | tasks_tr.append([(c1, c2), x_tr.clone().view(-1, 784), y_tr.clone().view(-1)]) 256 | 257 | return DataLoader(tasks_tr, self.args) 258 | 259 | def generate_task(self, nb_sample_train, multi_annotation=False, classe2generate=None): 260 | 261 | 262 | if nb_sample_train >= 1000: 263 | for i in range(int(nb_sample_train / 1000)): 264 | 265 | if i == 0: 266 | data_loader = self.create_data_loader(1000, classe2generate, multi_annotation) 267 | else: 268 | new_loader = self.create_data_loader(1000, classe2generate, multi_annotation) 269 | data_loader.concatenate(new_loader) 270 | 271 | # here we generate the remaining samples 272 | if nb_sample_train % 1000 != 0: 273 | new_loader = self.create_data_loader(nb_sample_train % 1000, classe2generate, multi_annotation) 274 | data_loader.concatenate(new_loader) 275 | 276 | else: 277 | data_loader = self.create_data_loader(nb_sample_train, classe2generate, multi_annotation) 278 | 279 | return data_loader 280 | 281 | # This function generate a dataset for one class or for all class until ind_task included 282 | def generate_dataset(self, ind_task, nb_sample_per_task, one_task=True, Train=True, classe2generate=None): 283 | 284 | # to generate 10 classes classe2generate is 9 as classes 0 to 9 285 | if classe2generate is not None: 286 | assert classe2generate <= self.num_classes 287 | if self.task_type != "disjoint": 288 | assert classe2generate == self.num_classes 289 | else: 290 | classe2generate = ind_task+1 291 | 292 | train_loader_gen=None 293 | 294 | if Train: 295 | path = os.path.join(self.gen_dir, 'train_Task_' + str(ind_task) + '.pt') 296 | path_samples = os.path.join(self.sample_dir, 'samples_train_' + str(ind_task) + '.png') 297 | else: 298 | path = os.path.join(self.gen_dir, 'test_Task_' + str(ind_task) + '.pt') 299 | path_samples = os.path.join(self.sample_dir, 'samples_test_' + str(ind_task) + '.png') 300 | 301 | # if we have only on task to generate 302 | if one_task or ind_task == 0: # generate only for the task ind_task 303 | 304 | train_loader_gen = self.generate_task(nb_sample_per_task, multi_annotation=False, classe2generate=classe2generate) 305 | 306 | else: # else case we generate for all previous task 307 | 308 | for i in range(ind_task): # we generate nb_sample_per_task * (ind_task+1) samples 309 | 310 | train_loader_ind = self.generate_task(nb_sample_per_task, multi_annotation=True, classe2generate=classe2generate) 311 | 312 | if i == 0: 313 | train_loader_gen = deepcopy(train_loader_ind) 314 | else: 315 | train_loader_gen.concatenate(train_loader_ind) 316 | 317 | # we save the concatenation of all generated with the actual task for train and test 318 | train_loader_gen.save(path) 319 | train_loader_gen.visualize_sample(path_samples, self.sample_num, [self.size, self.size, self.input_size]) 320 | 321 | # return the the train loader with all data 322 | return train_loader_gen # test_loader_gen # for instance we don't use the test set 323 | 324 | # this generation only works for Baseline, disjoint 325 | # we generate the dataset based on one generator by task to get normally the best generated dataset 326 | # can be used to generate train or test data 327 | def generate_best_dataset(self, ind_task, nb_sample_per_task, one_task=True, Train=True, classe2generate=None): 328 | 329 | 330 | # to generate 10 classes classe2generate is 9 as classes 0 to 9 331 | if classe2generate is not None: 332 | assert classe2generate <= self.num_classes 333 | if self.task_type != "disjoint": 334 | assert classe2generate == self.num_classes 335 | else: 336 | classe2generate = ind_task+1 337 | 338 | if Train: 339 | path = os.path.join(self.gen_dir, 'Best_train_Task_' + str(ind_task) + '.pt') 340 | else: 341 | path = os.path.join(self.gen_dir, 'Best_test_Task_' + str(ind_task) + '.pt') 342 | 343 | # if we have only on task to generate 344 | if ind_task == 0: # generate only for the task ind_task 345 | # we do not need automatic annotation since we have one generator by class 346 | previous_data_train = self.generate_task(nb_sample_per_task, multi_annotation=False, classe2generate=classe2generate) 347 | #previous_data_train = DataLoader(tasks_tr, self.args) 348 | 349 | else: # else we load the previous dataset and add the new data 350 | 351 | previous_path_train = os.path.join(self.gen_dir, 'Best_train_Task_' + str(ind_task - 1) + '.pt') 352 | 353 | previous_data_train = DataLoader(torch.load(previous_path_train), self.args) 354 | 355 | # we do not need automatic annotation since we have one generator by class 356 | train_loader_ind = self.generate_task(nb_sample_per_task, multi_annotation=False, classe2generate=classe2generate) 357 | 358 | previous_data_train.concatenate(train_loader_ind) 359 | 360 | # we save the concatenation of all generated with the actual task for train and test 361 | previous_data_train.save(path) 362 | 363 | # return nothing 364 | 365 | -------------------------------------------------------------------------------- /Generative_Models/VAE.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import pickle 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.autograd import Variable 12 | 13 | from Generative_Models.Generative_Model import GenerativeModel 14 | from Generative_Models.discriminator import Discriminator 15 | from Generative_Models.encoder import Encoder 16 | from Generative_Models.generator import Generator 17 | 18 | from utils import variable 19 | 20 | 21 | class VAE(GenerativeModel): 22 | def __init__(self, args): 23 | 24 | super(VAE, self).__init__(args) 25 | 26 | 27 | self.E = Encoder(self.z_dim, self.dataset, self.conditional) 28 | self.E_optimizer = optim.Adam(self.E.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 29 | self.lr = args.lrD 30 | 31 | if self.gpu_mode: 32 | self.E.cuda(self.device) 33 | 34 | self.sample_z_ = variable(torch.randn((self.sample_num, self.z_dim, 1, 1))) 35 | 36 | 37 | def load_G(self, ind_task): 38 | self.G.load_state_dict( 39 | torch.load(os.path.join(self.save_dir, self.model_name + '-' + str(ind_task) + '_G.pkl'))) 40 | 41 | # self.E.load_state_dict(torch.load(os.path.join(self.save_dir, self.model_name + '_E.pkl'))) 42 | 43 | # save a generator, encoder and discriminator in a given class 44 | def save(self): 45 | if not os.path.exists(self.save_dir): 46 | os.makedirs(self.save_dir) 47 | 48 | torch.save(self.G.state_dict(), os.path.join(self.save_dir, self.model_name + '_G.pkl')) 49 | torch.save(self.E.state_dict(), os.path.join(self.save_dir, self.model_name + '_E.pkl')) 50 | 51 | with open(os.path.join(self.save_dir, self.model_name + '_history.pkl'), 'wb') as f: 52 | pickle.dump(self.train_hist, f) 53 | 54 | def random_tensor(self, batch_size, z_dim): 55 | # From Normal distribution for VAE and CVAE 56 | return torch.randn((batch_size, z_dim, 1, 1)) 57 | 58 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 59 | 60 | start_time = time.time() 61 | 62 | n_batch = 0 63 | 64 | self.E.train() 65 | self.G.train() 66 | sum_loss_train = 0. 67 | 68 | for iter, (x_, _) in enumerate(train_loader): 69 | 70 | 71 | n_batch += 1 72 | x_ = Variable(x_) 73 | if self.gpu_mode: 74 | x_ = x_.cuda(self.device) 75 | # VAE 76 | z_, mu, logvar = self.E(x_) 77 | recon_batch = self.G(z_) 78 | # train 79 | self.G_optimizer.zero_grad() 80 | self.E_optimizer.zero_grad() 81 | g_loss = self.loss_function(recon_batch, x_.view(-1,1,28,28), mu, logvar) 82 | sum_loss_train += g_loss.item() 83 | 84 | #regularization = additional_loss([self.G, self.E]) 85 | regularization = additional_loss(self) 86 | #regularization = additional_loss(self.E) 87 | 88 | if regularization is not None: 89 | g_loss += regularization 90 | 91 | g_loss.backward() 92 | self.G_optimizer.step() 93 | self.E_optimizer.step() 94 | 95 | if self.verbose: 96 | if ((iter + 1) % 100) == 0: 97 | print("Task : [%1d] Epoch: [%2d] [%4d/%4d] G_loss: %.8f, E_loss: %.8f" % 98 | (ind_task, (epoch + 1), (iter + 1), len(train_loader), g_loss.item(), g_loss.item())) 99 | 100 | sum_loss_train = sum_loss_train / np.float(n_batch) 101 | 102 | return sum_loss_train 103 | 104 | def loss_function(self, recon_x, x, mu, logvar): 105 | if self.dataset == 'mnist' or self.dataset == 'fashion-mnist': 106 | reconstruction_function = nn.BCELoss() 107 | else: 108 | reconstruction_function = nn.MSELoss() 109 | reconstruction_function.size_average = False 110 | 111 | recon_x=recon_x.view(-1,self.input_size, self.size, self.size) 112 | x=x.view(-1,self.input_size, self.size, self.size) 113 | 114 | BCE = reconstruction_function(recon_x, x) 115 | 116 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 117 | KLD = torch.sum(KLD_element).mul_(-0.5) 118 | 119 | if self.gpu_mode: 120 | BCE = BCE.cuda(self.device) 121 | KLD = KLD.cuda(self.device) 122 | return BCE + KLD 123 | 124 | def train(self): 125 | self.G.train() 126 | self.E.train() 127 | 128 | def eval(self): 129 | self.G.eval() 130 | self.E.eval() 131 | -------------------------------------------------------------------------------- /Generative_Models/WGAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | from utils import variable 5 | from Generative_Models.Generative_Model import GenerativeModel 6 | 7 | 8 | class WGAN(GenerativeModel): 9 | def __init__(self, args): 10 | super(WGAN, self).__init__(args) 11 | self.c = 0.01 # clipping value 12 | self.n_critic = 2 # the number of iterations of the critic per generator iteration 13 | 14 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 15 | 16 | 17 | #self.y_real_ = variable(torch.ones(self.batch_size, 1)) 18 | #self.y_fake = variable(torch.zeros(self.batch_size, 1)) 19 | 20 | self.G.train() 21 | self.D.train() 22 | 23 | epoch_start_time = time.time() 24 | sum_loss_train = 0. 25 | 26 | 27 | for iter, (x_, t_ ) in enumerate(train_loader): 28 | 29 | x_ = variable(x_.view((-1, self.input_size, self.size, self.size))) 30 | z_ = variable(torch.rand((self.batch_size, self.z_dim, 1, 1))) 31 | 32 | 33 | self.D_optimizer.zero_grad() 34 | D_real = self.D(x_) 35 | D_real_loss = -torch.mean(D_real) 36 | 37 | G_ = self.G(z_) 38 | D_fake = self.D(G_) 39 | D_fake_loss = torch.mean(D_fake) 40 | 41 | D_loss = D_real_loss + D_fake_loss 42 | 43 | D_loss.backward() 44 | self.D_optimizer.step() 45 | 46 | # clipping D 47 | for p in self.D.parameters(): 48 | p.data.clamp_(-self.c, self.c) 49 | 50 | if ((iter + 1) % self.n_critic) == 0: 51 | # update G network 52 | self.G_optimizer.zero_grad() 53 | 54 | G_ = self.G(z_) 55 | D_fake = self.D(G_) 56 | G_loss = -torch.mean(D_fake) 57 | self.train_hist['G_loss'].append(G_loss.item()) 58 | 59 | G_loss.backward() 60 | self.G_optimizer.step() 61 | 62 | self.train_hist['D_loss'].append(D_loss.item()) 63 | 64 | if self.verbose: 65 | if ((iter + 1) % 100) == 0: 66 | print("ind_task : [%1d] Epoch: [%2d] [%4d/%4d] G_loss: %.8f, D_loss: %.8f" % 67 | (ind_task, (epoch + 1), (iter + 1), self.size_epoch, G_loss.item(), D_loss.item())) 68 | 69 | 70 | #the following line is probably wrong 71 | self.train_hist['total_time'].append(time.time() - epoch_start_time) 72 | 73 | 74 | if self.verbose: 75 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 76 | self.epoch, self.train_hist['total_time'][0])) 77 | print("Training finish!... save training results") 78 | 79 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 80 | self.save() 81 | 82 | sum_loss_train = sum_loss_train / np.float(len(train_loader)) 83 | 84 | return sum_loss_train 85 | 86 | -------------------------------------------------------------------------------- /Generative_Models/WGAN_GP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from torch.autograd import Variable 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import grad 9 | import torch 10 | 11 | import torch.optim as optim 12 | from utils import * 13 | 14 | from Generative_Models.Generative_Model import GenerativeModel 15 | 16 | import time 17 | 18 | 19 | class WGAN_GP(GenerativeModel): 20 | def __init__(self, args): 21 | 22 | super(WGAN_GP, self).__init__(args) 23 | 24 | 25 | self.model_name = 'WGAN_GP' 26 | self.lambda_ = 0.25 27 | 28 | # Loss weight for gradient penalty 29 | self.lambda_gp = 0.1 # 0.25 #10 30 | self.cuda = True 31 | self.c = 0.01 # clipping value 32 | self.n_critic = 2 # the number of iterations of the critic per generator iteration 33 | self.Tensor = torch.cuda.FloatTensor if True else torch.FloatTensor 34 | 35 | 36 | self.y_real_ = torch.FloatTensor([1]) 37 | self.y_fake_ = self.y_real_ * -1 38 | 39 | 40 | if self.gpu_mode: 41 | self.y_real_, self.y_fake_ = self.y_real_.cuda(self.device), self.y_fake_.cuda(self.device) 42 | 43 | 44 | 45 | def random_tensor(self, batch_size, z_dim): 46 | # From Normal distribution for VAE and CVAE 47 | return torch.randn((batch_size, z_dim, 1, 1)) 48 | 49 | def train_on_task(self, train_loader, ind_task, epoch, additional_loss): 50 | 51 | self.G.train() 52 | self.D.train() 53 | 54 | epoch_start_time = time.time() 55 | sum_loss_train = 0. 56 | 57 | for iter, (x_, t_ ) in enumerate(train_loader): 58 | 59 | for p in self.D.parameters(): # reset requires_grad 60 | p.requires_grad = True # they are set to False below in netG update 61 | 62 | x_ = variable(x_) 63 | z_ = variable(self.random_tensor(x_.size(0), self.z_dim)) 64 | 65 | # update D network 66 | self.D_optimizer.zero_grad() 67 | 68 | x_ = x_.view(-1, self.input_size, self.size, self.size) 69 | 70 | D_real = self.D(x_) 71 | D_real_loss = -torch.mean(D_real) 72 | 73 | G_ = self.G(z_) 74 | D_fake = self.D(G_) 75 | D_fake_loss = torch.mean(D_fake) 76 | 77 | # gradient penalty 78 | if self.gpu_mode: 79 | alpha = torch.rand((x_.size(0), 1, 1, 1)).cuda() 80 | else: 81 | alpha = torch.rand((x_.size(0), 1, 1, 1)) 82 | 83 | x_hat = Variable(alpha * x_.data + (1 - alpha) * G_.data, requires_grad=True) 84 | 85 | if self.gpu_mode: 86 | x_hat=x_hat.cuda() 87 | 88 | pred_hat = self.D(x_hat.view(-1, self.input_size, self.size, self.size)) 89 | 90 | gradients = grad(outputs=pred_hat, inputs=x_hat, 91 | grad_outputs=torch.ones(pred_hat.size()).cuda() if self.gpu_mode else torch.ones( 92 | pred_hat.size()), 93 | create_graph=True, retain_graph=True, only_inputs=True)[0] 94 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_gp 95 | """ 96 | if self.gpu_mode: 97 | gradients = \ 98 | grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(), 99 | create_graph=True, retain_graph=True, only_inputs=True)[0] 100 | else: 101 | gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()), 102 | create_graph=True, retain_graph=True, only_inputs=True)[0] 103 | gradient_penalty = self.lambda_gp * ( 104 | (gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() 105 | """ 106 | D_loss = D_real_loss + D_fake_loss + gradient_penalty 107 | 108 | D_loss.backward() 109 | self.D_optimizer.step() 110 | 111 | if ((iter + 1) % self.n_critic) == 0: 112 | 113 | for p in self.D.parameters(): 114 | p.requires_grad = False # to avoid computation 115 | 116 | # update G network 117 | self.G_optimizer.zero_grad() 118 | 119 | z_ = variable(self.random_tensor(x_.size(0), self.z_dim)) 120 | 121 | G_ = self.G(z_) 122 | D_fake = self.D(G_) 123 | G_loss = -torch.mean(D_fake) 124 | 125 | G_loss.backward() 126 | self.G_optimizer.step() 127 | 128 | #the following line is probably wrong 129 | self.train_hist['total_time'].append(time.time() - epoch_start_time) 130 | 131 | 132 | if self.verbose: 133 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 134 | self.epoch, self.train_hist['total_time'][0])) 135 | print("Training finish!... save training results") 136 | 137 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 138 | self.save() 139 | 140 | sum_loss_train = sum_loss_train / np.float(len(train_loader)) 141 | 142 | return sum_loss_train 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /Generative_Models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TLESORT/Generative_Continual_Learning/66b121437c248993b41f154b5a2d6b7197278578/Generative_Models/__init__.py -------------------------------------------------------------------------------- /Generative_Models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import utils 6 | 7 | 8 | class Discriminator(nn.Module): 9 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 10 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 11 | def __init__(self, dataset='mnist', conditional=False, model='VAE'): 12 | super(Discriminator, self).__init__() 13 | self.dataset = dataset 14 | self.model = model 15 | self.conditional = conditional 16 | if dataset == 'mnist' or dataset == 'fashion': 17 | self.input_height = 28 18 | self.input_width = 28 19 | self.input_dim = 1 20 | self.output_dim = 1 21 | 22 | self.latent_dim = 1024 23 | 24 | shape = 128 * (self.input_height // 4) * (self.input_width // 4) 25 | 26 | self.fc1_1 = nn.Linear(784, self.latent_dim) 27 | self.fc1_2 = nn.Linear(10, self.latent_dim) 28 | self.fc2 = nn.Linear(self.latent_dim * 2, self.latent_dim // 2) 29 | self.fc2_bn = nn.BatchNorm1d(self.latent_dim // 2) 30 | self.fc3 = nn.Linear(self.latent_dim // 2, 256) 31 | self.fc3_bn = nn.BatchNorm1d(256) 32 | self.fc4 = nn.Linear(256, 1) 33 | 34 | self.conv = nn.Sequential( 35 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 36 | nn.LeakyReLU(0.2), 37 | nn.Conv2d(64, 128, 4, 2, 1), 38 | nn.BatchNorm2d(128), 39 | nn.LeakyReLU(0.2), 40 | ) 41 | self.fc = nn.Sequential( 42 | nn.Linear(shape, self.latent_dim), 43 | nn.BatchNorm1d(self.latent_dim), 44 | nn.LeakyReLU(0.2), 45 | nn.Linear(self.latent_dim, self.output_dim), 46 | nn.Sigmoid(), 47 | ) 48 | self.aux_linear = nn.Linear(shape, 10) 49 | self.softmax = nn.Softmax() 50 | self.apply(self.weights_init) 51 | 52 | if self.model == 'BEGAN': 53 | self.be_conv = nn.Sequential( 54 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 55 | nn.ReLU(), 56 | ) 57 | self.be_fc = nn.Sequential( 58 | nn.Linear(64 * (self.input_height // 2) * (self.input_width // 2), 32), 59 | nn.BatchNorm1d(32), 60 | nn.ReLU(), 61 | nn.Linear(32, 64 * (self.input_height // 2) * (self.input_width // 2)), 62 | nn.BatchNorm1d(64 * (self.input_height // 2) * (self.input_width // 2)), 63 | nn.ReLU(), 64 | ) 65 | self.be_deconv = nn.Sequential( 66 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 67 | # nn.Sigmoid(), 68 | ) 69 | utils.initialize_weights(self) 70 | 71 | def weights_init(self, m): 72 | classname = m.__class__.__name__ 73 | if classname.find('Conv') != -1: 74 | m.weight.data.normal_(0.0, 0.02) 75 | elif classname.find('BatchNorm') != -1: 76 | m.weight.data.normal_(1.0, 0.02) 77 | m.bias.data.fill_(0) 78 | 79 | def disc_cgan(self, input, label): 80 | input = input.view(-1, self.input_height * self.input_width * self.input_dim) 81 | x = F.leaky_relu(self.fc1_1(input), 0.2) 82 | y = F.leaky_relu(self.fc1_2(label), 0.2) 83 | x = torch.cat([x, y], 1) 84 | x = F.leaky_relu(self.fc2_bn(self.fc2(x)), 0.2) 85 | x = F.leaky_relu(self.fc3_bn(self.fc3(x)), 0.2) 86 | x = F.sigmoid(self.fc4(x)) 87 | return x, label 88 | 89 | def disc_began(self, input): 90 | x = self.be_conv(input) 91 | x = x.view(x.size()[0], -1) 92 | x = self.be_fc(x) 93 | x = x.view(-1, 64, (self.input_height // 2), (self.input_width // 2)) 94 | x = self.be_deconv(x) 95 | return x 96 | 97 | def forward(self, input, c=None): 98 | # print(input.data.shape) 99 | if self.model == 'BEGAN': 100 | return self.disc_began(input) 101 | 102 | if self.model == "CGAN" or (self.model == 'GAN' and self.conditional): # CGAN 103 | return self.disc_cgan(input, c) 104 | 105 | x = self.conv(input) 106 | x = x.view(x.data.shape[0], 128 * (self.input_height // 4) * (self.input_width // 4)) 107 | 108 | final = self.fc(x) 109 | if c is not None: 110 | c = self.aux_linear(x) 111 | c = self.softmax(c) 112 | return final, c 113 | else: 114 | return final 115 | 116 | class Discriminator_Cifar(nn.Module): 117 | def __init__(self, conditional): 118 | super(Discriminator_Cifar, self).__init__() 119 | self.nc = 3 120 | self.ndf = 32 121 | self.ngpu = 1 122 | self.conditional = conditional 123 | 124 | self.main = nn.Sequential( 125 | # input is (nc) x 64 x 64 126 | #nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 127 | #nn.LeakyReLU(0.2, inplace=True), 128 | # state size. (ndf) x 32 x 32 129 | nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False), 130 | nn.BatchNorm2d(self.ndf), 131 | nn.LeakyReLU(0.2, inplace=True), 132 | # state size. (ndf*2) x 16 x 16 133 | nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False), 134 | #nn.BatchNorm2d(self.ndf * 2), 135 | nn.LeakyReLU(0.2, inplace=True), 136 | # state size. (ndf*4) x 8 x 8 137 | nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False), 138 | #nn.BatchNorm2d(self.ndf * 4), 139 | nn.LeakyReLU(0.2, inplace=True), 140 | # state size. (ndf*8) x 4 x 4 141 | nn.Conv2d(self.ndf * 4, 1, 4, 1, 0, bias=False), 142 | nn.Sigmoid() 143 | ) 144 | 145 | def forward(self, input, c=None): 146 | input = input.view(-1,3,32,32) 147 | if input.is_cuda and self.ngpu > 1: 148 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 149 | else: 150 | output = self.main(input) 151 | 152 | return output.view(-1, 1).squeeze(1) 153 | 154 | def weights_init(self, m): 155 | classname = m.__class__.__name__ 156 | if classname.find('Conv') != -1: 157 | m.weight.data.normal_(0.0, 0.02) 158 | elif classname.find('BatchNorm') != -1: 159 | m.weight.data.normal_(1.0, 0.02) 160 | m.bias.data.fill_(0) 161 | -------------------------------------------------------------------------------- /Generative_Models/encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | class Encoder(nn.Module): 7 | def __init__(self, z_dim, dataset='mnist', conditional=False): 8 | super(Encoder, self).__init__() 9 | self.z_dim = z_dim 10 | self.conditional = conditional 11 | if dataset == 'mnist' or dataset == 'fashion': 12 | self.input_size = 784 13 | elif dataset == 'celebA': 14 | self.input_size = 64 * 64 * 3 15 | elif dataset == 'cifar10': 16 | self.input_size = 32 * 32 * 3 17 | # self.input_size = 64 * 64 * 3 18 | elif dataset == 'timagenet': 19 | self.input_size = 64 * 64 * 3 20 | if self.conditional: 21 | self.input_size += 10 22 | self.relu = nn.ReLU() 23 | self.sigmoid = nn.Sigmoid() 24 | self.fc1 = nn.Linear(self.input_size, 1200) 25 | self.fc21 = nn.Linear(1200, z_dim) 26 | self.fc22 = nn.Linear(1200, z_dim) 27 | 28 | def encode(self, x, c=None): 29 | if self.conditional: 30 | x = torch.cat([x, c], 1) 31 | h1 = self.relu(self.fc1(x)) 32 | return self.fc21(h1), self.fc22(h1) 33 | 34 | def reparametrize(self, mu, logvar): 35 | 36 | std = logvar.mul(0.5).exp_() 37 | if torch.cuda.is_available(): 38 | eps = torch.cuda.FloatTensor(std.size()).normal_() # does not work for other device than 0 39 | else: 40 | eps = torch.FloatTensor(std.size()).normal_() 41 | eps = Variable(eps) 42 | return eps.mul(std).add_(mu) 43 | 44 | def forward(self, x, c=None): 45 | mu, logvar = self.encode(x.view(x.size(0), -1), c) 46 | z = self.reparametrize(mu, logvar) 47 | return z, mu, logvar -------------------------------------------------------------------------------- /Generative_Models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def Generator(z_dim=62, dataset='mnist', conditional=False, model='VAE'): 6 | if dataset == 'mnist' or dataset == 'fashion': 7 | return MNIST_Generator(z_dim, dataset, conditional, model) 8 | # else: 9 | # raise ValueError("This generator is not implemented") 10 | 11 | 12 | class MNIST_Generator(nn.Module): 13 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 14 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 15 | def __init__(self, z_dim=62, dataset='mnist', conditional=False, model='VAE'): 16 | super(MNIST_Generator, self).__init__() 17 | self.dataset = dataset 18 | self.z_dim = z_dim 19 | self.model = model 20 | self.conditional = conditional 21 | 22 | self.latent_dim = 1024 23 | 24 | self.input_height = 28 25 | self.input_width = 28 26 | self.input_dim = z_dim 27 | if self.conditional: 28 | self.input_dim += 10 29 | self.output_dim = 1 30 | 31 | self.fc = nn.Sequential( 32 | nn.Linear(self.input_dim, self.latent_dim), 33 | nn.BatchNorm1d(self.latent_dim), 34 | nn.ReLU(), 35 | nn.Linear(self.latent_dim, 128 * (self.input_height // 4) * (self.input_width // 4)), 36 | nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)), 37 | nn.ReLU(), 38 | ) 39 | self.deconv = nn.Sequential( 40 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 41 | nn.BatchNorm2d(64), 42 | nn.ReLU(), 43 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 44 | nn.Sigmoid(), 45 | ) 46 | 47 | self.maxPool = nn.MaxPool2d((2, 2), stride=(2, 2)) 48 | self.Sigmoid = nn.Sigmoid() 49 | self.apply(self.weights_init) 50 | 51 | def forward(self, input, c=None): 52 | 53 | if c is not None: 54 | input = input.view(-1, self.input_dim - 10) 55 | input = torch.cat([input, c], 1) 56 | else: 57 | input = input.view(-1, self.input_dim) 58 | 59 | x = self.fc(input) 60 | x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4)) 61 | x = self.deconv(x) 62 | return x 63 | 64 | def weights_init(self, m): 65 | classname = m.__class__.__name__ 66 | if classname.find('Conv') != -1: 67 | m.weight.data.normal_(0.0, 0.02) 68 | elif classname.find('BatchNorm') != -1: 69 | m.weight.data.normal_(1.0, 0.02) 70 | m.bias.data.fill_(0) 71 | 72 | 73 | class Generator_Cifar(nn.Module): 74 | def __init__(self, z_dim, conditional=False): 75 | super(Generator_Cifar, self).__init__() 76 | self.nc = 3 77 | 78 | self.conditional = conditional 79 | 80 | self.nz = z_dim 81 | if self.conditional: 82 | self.nz += 10 83 | 84 | self.ngf = 64 85 | self.ndf = 64 86 | self.ngpu = 1 87 | 88 | self.Conv1 = nn.ConvTranspose2d(self.nz, self.ngf * 8, 4, 1, 0, bias=False) 89 | self.BN1 = nn.BatchNorm2d(self.ngf * 8) 90 | self.Relu = nn.ReLU(True) 91 | # state size. (ngf*8) x 4 x 4 92 | self.Conv2 = nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False) 93 | self.BN2 = nn.BatchNorm2d(self.ngf * 4) 94 | # nn.ReLU(True), 95 | # state size. (ngf*4) x 8 x 8 96 | self.Conv3 = nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False) 97 | self.BN3 = nn.BatchNorm2d(self.ngf * 2) 98 | # nn.ReLU(True), 99 | # state size. (ngf*2) x 16 x 16 100 | # nn.ConvTranspose2d(self.ngf * 2,self.ngf, 4, 2, 1, bias=False), 101 | self.Conv4 = nn.ConvTranspose2d(self.ngf * 2, self.nc, 4, 2, 1, bias=False) 102 | # nn.BatchNorm2d(self.ngf), 103 | # nn.ReLU(True), 104 | # state size. (ngf) x 32 x 32 105 | # nn.ConvTranspose2d(self.ngf,self.nc, 4, 2, 1, bias=False), 106 | self.Tanh = nn.Tanh() 107 | # state size. (nc) x 64 x 64 108 | 109 | ''' 110 | self.main = nn.Sequential( 111 | # input is Z, going into a convolution 112 | nn.ConvTranspose2d(self.nz, self.ngf * 8, 4, 1, 0, bias=False), 113 | nn.BatchNorm2d(self.ngf * 8), 114 | nn.ReLU(True), 115 | # state size. (ngf*8) x 4 x 4 116 | nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False), 117 | nn.BatchNorm2d(self.ngf * 4), 118 | nn.ReLU(True), 119 | # state size. (ngf*4) x 8 x 8 120 | nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False), 121 | nn.BatchNorm2d(self.ngf * 2), 122 | nn.ReLU(True), 123 | # state size. (ngf*2) x 16 x 16 124 | #nn.ConvTranspose2d(self.ngf * 2,self.ngf, 4, 2, 1, bias=False), 125 | nn.ConvTranspose2d(self.ngf * 2,self.nc, 4, 2, 1, bias=False), 126 | #nn.BatchNorm2d(self.ngf), 127 | #nn.ReLU(True), 128 | # state size. (ngf) x 32 x 32 129 | #nn.ConvTranspose2d(self.ngf,self.nc, 4, 2, 1, bias=False), 130 | nn.Tanh() 131 | # state size. (nc) x 64 x 64 132 | ) 133 | ''' 134 | 135 | def forward(self, input, c=None): 136 | 137 | if c is not None: 138 | input = input.view(-1, self.nz - 10, 1, 1) 139 | input = torch.cat([input, c], 1) 140 | else: 141 | input = input.view(-1, self.nz, 1, 1) 142 | 143 | x = self.Relu(self.BN1(self.Conv1(input))) 144 | x = self.Relu(self.BN2(self.Conv2(x))) 145 | x = self.Relu(self.BN3(self.Conv3(x))) 146 | x = self.Tanh(self.Conv4(x)) 147 | 148 | return x 149 | 150 | def weights_init(self, m): 151 | classname = m.__class__.__name__ 152 | if classname.find('Conv') != -1: 153 | m.weight.data.normal_(0.0, 0.02) 154 | elif classname.find('BatchNorm') != -1: 155 | m.weight.data.normal_(1.0, 0.02) 156 | m.bias.data.fill_(0) 157 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Timothée Lesort 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Models from the perspective of Continual Learning 2 | *Timothée Lesort, Hugo Caselles-Dupré, Michael Garcia-Ortiz, Andrei Stoian, David Filliat*; **IJCNN 2019, Budapest** 3 | 4 | ## Abstract 5 | 6 | Which generative model is the most suitable for Continual Learning? This paper aims at evaluating and comparing generative models on disjoint sequential image generation tasks.
7 | We investigate how several models learn and forget, considering various strategies: rehearsal, regularization, generative replay and fine-tuning. We used two quantitative metrics to estimate the generation quality and memory ability. We experiment with sequential tasks on three commonly used benchmarks for Continual Learning (MNIST, Fashion MNIST and CIFAR10).
8 | We found that among all models, the original GAN performs best and among Continual Learning strategies, generative replay outperforms all other methods. Even if we found satisfactory combinations on MNIST and Fashion MNIST, training generative models sequentially on CIFAR10 is particularly instable, and remains a challenge.
9 | 10 | Sequence of Task 11 | Example of generative tasks sequence and generation capability to reach. 12 | 13 | 14 | ### Citing the Project 15 | 16 | ```Array. 17 | @inproceedings{lesort2019generative, 18 | title={Generative models from the perspective of continual learning}, 19 | author={Lesort, Timoth{\'e}e and Caselles-Dupr{\'e}, Hugo and Garcia-Ortiz, Michael and Stoian, Andrei and Filliat, David}, 20 | booktitle={2019 International Joint Conference on Neural Networks (IJCNN)}, 21 | pages={1--8}, 22 | year={2019}, 23 | organization={IEEE} 24 | } 25 | ``` 26 | ## Installation 27 | 28 | ### Clone Repos 29 | 30 | ```bash 31 | git clone https://github.com/TLESORT/Generation_Incremental.git 32 | ``` 33 | 34 | ### Create Set-up 35 | 36 | #### Manual 37 | 38 | ```bash 39 | pytorch 0.4 40 | torchvision 0.2.1 41 | imageio 2.2.0 42 | tqdm 4.19.5 43 | ``` 44 | 45 | #### Conda environmnet 46 | 47 | ```bash 48 | conda env create -f environment.yml 49 | source activate py36 50 | ``` 51 | 52 | #### Docker environmnet 53 | 54 | TODO 55 | 56 | ## Experiments Done 57 | 58 | #### Dataset 59 | 60 | * MNIST 61 | * Fashion MNIST 62 | 63 | 64 | #### Generative Models 65 | 66 | * GAN 67 | * CGAN 68 | * WGAN 69 | * WGAN_GP 70 | * VAE 71 | * CVAE 72 | 73 | #### Task 74 | 75 | * Disjoint tasks -> 10 tasks 76 | 77 | 78 | #### To Add 79 | 80 | * Cifar10 81 | 82 | ## Run experiments 83 | 84 | 85 | ```bash 86 | 87 | cd Scripts 88 | ./generate_test.sh 89 | ./test_todo.sh 90 | ``` 91 | 92 | 93 | NB : Test todo will contains all bash commands to run since it may takes some days to run them all you can choose one of them manually and run it in the main repository 94 | Manual Example of commands for training and evaluating *Generative_replay* with *GAN* on Mnist : 95 | 96 | Generate Data 97 | ```bash 98 | cd ./Data 99 | #For the expert 100 | python main_data.py --task disjoint --dataset mnist --n_tasks 1 --dir ../Archives 101 | #For the models to train 102 | python main_data.py --task disjoint --dataset mnist --n_tasks 10 --dir ../Archives 103 | #For Upperbound and FID 104 | python main_data.py --task disjoint --upperbound True --dataset mnist --n_tasks 10 --dir ../Archives 105 | 106 | # Go back to main repo 107 | cd .. 108 | ``` 109 | 110 | Train Expert to compute later FID 111 | ```bash 112 | python main.py --context Classification --task_type disjoint --method Baseline --dataset mnist --epochs 50 --epoch_Review 50 --num_task 1 --seed 0 --dir ./Archives 113 | ``` 114 | 115 | Train Generator 116 | ```bash 117 | python main.py --context Generation --task_type disjoint --method Generative_Replay --dataset mnist --epochs 50 --num_task 10 --gan_type GAN --train_G True --seed 0 --dir ./Archives 118 | ``` 119 | 120 | Review Generator with Fitting Capacity 121 | ```bash 122 | python main.py --context Generation --task_type disjoint --method Generative_Replay --dataset mnist --epochs 50 --num_task 10 --gan_type GAN --Fitting_capacity True --seed 0 --dir ./Archives 123 | ``` 124 | 125 | Review Generator with FID 126 | ```bash 127 | python main.py --context Generation --task_type disjoint --method Generative_Replay --dataset mnist --epochs 50 --num_task 10 --gan_type GAN --FID True --seed 0 --dir ./Archives 128 | ``` 129 | 130 | ## print figures 131 | 132 | Go to the main repository 133 | 134 | Plot Fitting Capacity 135 | ```bash 136 | python print_figures.py --fitting_capacity True 137 | ``` 138 | 139 | Plot FID 140 | ```bash 141 | python print_figures.py --FID True 142 | ``` 143 | 144 | 145 | 146 | 147 | 148 | 153 | 158 | 159 | 160 | 161 |
149 | Fitting capacity : GAN MNIST 150 |
151 | Fitting capacity at each task : GAN MNIST 152 |
154 | FID : GAN MNIST 155 |
156 | Fashion-Mnist at each task results. 157 |
162 | 163 | ## Plot Samples 164 | 165 | Samples MNIST 166 | 167 | -------------------------------------------------------------------------------- /Scripts/generate_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | fileName=test_todo.sh 4 | epochs=50 5 | epoch_Review=50 6 | epoch_G=1 7 | num_task=10 8 | 9 | 10 | rm $fileName 11 | echo '#!/bin/bash' >> $fileName 12 | chmod +x $fileName 13 | 14 | datasets="mnist fashion" 15 | models="WGAN VAE CVAE CGAN GAN WGAN_GP" #BEGAN 16 | methods="Baseline Generative_Replay Reharsal Ewc" # 17 | seeds="0 1 2 3 4 5 6 7" 18 | seeds="0" 19 | 20 | 21 | ########################## GENERATE INPUT DATA FILES ########################### 22 | 23 | echo 'cd ../Data' >> $fileName 24 | 25 | dir="../Archives" 26 | 27 | echo '#Generate Inpu Data' >> $fileName 28 | for dataset in $datasets ;do 29 | echo '#'$dataset >> $fileName 30 | echo '#For the expert' >> $fileName 31 | echo python main.py --task disjoint --dataset $dataset --n_tasks 1 --dir $dir >> $fileName 32 | 33 | echo '#For the models to train' >> $fileName 34 | echo python main.py --task disjoint --dataset $dataset --n_tasks $num_task --dir $dir >> $fileName 35 | 36 | echo '#For Upperbound' >> $fileName 37 | echo python main.py --task disjoint --upperbound True --dataset $dataset --n_tasks $num_task --dir $dir >> $fileName 38 | done #datasets 39 | 40 | 41 | 42 | 43 | ############### EXPERTS ###################### 44 | 45 | dir="./Archives" 46 | echo 'cd ..' >> $fileName 47 | 48 | 49 | for seed in $seeds; do 50 | 51 | 52 | #fileName=test_todo$seed\.sh 53 | 54 | #rm $fileName 55 | #echo '#!/bin/bash' >> $fileName 56 | #echo 'cd ..' >> $fileName # Go back to main folder 57 | 58 | for dataset in $datasets ;do 59 | 60 | #if [ "$dataset" == "mnist" ]; then 61 | # fileName=test_todo$seed\.sh 62 | #elif [ "$dataset" == "fashion" ]; then 63 | # fileName=test_todo$seed\bis.sh 64 | #fi 65 | #rm $fileName 66 | #echo '#!/bin/bash' >> $fileName 67 | #echo 'cd ..' >> $fileName # Go back to main folder 68 | #chmod +x $fileName 69 | 70 | 71 | 72 | echo python main.py --context Classification --task_type disjoint --method Baseline --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --num_task 1 --seed $seed --dir $dir >> $fileName 73 | ############### TEST TODO ###################### 74 | 75 | echo '#'$dataset >> $fileName 76 | 77 | for model in $models ;do 78 | echo '#'$model >> $fileName 79 | 80 | for method in $methods;do 81 | echo '#'$method >> $fileName 82 | 83 | echo python main.py --context Generation --task_type disjoint --method $method --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --train_G True --seed $seed --dir $dir >> $fileName 84 | echo python main.py --context Generation --task_type disjoint --method $method --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --FID True --seed $seed --dir $dir >> $fileName 85 | echo python main.py --context Generation --task_type disjoint --method $method --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --Fitting_capacity True --seed $seed --dir $dir >> $fileName 86 | echo python main.py --context Generation --task_type disjoint --method $method --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --trainEval True --seed $seed --dir $dir >> $fileName 87 | 88 | 89 | done #method 90 | 91 | 92 | ################ UPPERBOUND #################### 93 | 94 | 95 | echo '#Upperbound #' >> $fileName 96 | echo '#'$dataset >> $fileName 97 | 98 | echo python main.py --context Generation --task_type disjoint --method Baseline --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --upperbound True --train_G True --seed $seed --dir $dir >> $fileName 99 | echo python main.py --context Generation --task_type disjoint --method Baseline --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --upperbound True --FID True --seed $seed --dir $dir >> $fileName 100 | echo python main.py --context Generation --task_type disjoint --method Baseline --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --upperbound True --Fitting_capacity True --seed $seed --dir $dir >> $fileName 101 | echo python main.py --context Generation --task_type disjoint --method Baseline --dataset $dataset --epochs $epochs --epoch_Review $epoch_Review --epoch_G $epoch_G --num_task $num_task --gan_type $model --upperbound True --trainEval True --seed $seed --dir $dir >> $fileName 102 | 103 | done #model 104 | 105 | done #dataset 106 | 107 | done #seed 108 | -------------------------------------------------------------------------------- /Training/Baseline.py: -------------------------------------------------------------------------------- 1 | 2 | from Training.Trainer import Trainer 3 | 4 | 5 | class Baseline(Trainer): 6 | def __init__(self, model, args, reviewer): 7 | super(Baseline, self).__init__(model, args, reviewer) 8 | 9 | self.reviewer = reviewer 10 | 11 | def preparation_4_task(self, model, ind_task): 12 | 13 | if ind_task > 0 and self.context == 'Generation': 14 | # We generate as much image as there is in the actual task 15 | #nb_sample_train = len(self.train_loader[ind_task]) 16 | #nb_sample_test = int(nb_sample_train * 0.2) 17 | 18 | #print("nb_sample_train should be higher than this") 19 | print("numbe of train sample is fixed as : " + str(self.sample_transfer)) 20 | nb_sample_train = self.sample_transfer 21 | nb_sample_test = int(nb_sample_train * 0.2) 22 | 23 | # we generate dataset for later evaluation with image from previous tasks 24 | self.model.generate_dataset(ind_task - 1, nb_sample_train, one_task=False, Train=True, classe2generate = ind_task) 25 | #self.model.generate_dataset(ind_task - 1, nb_sample_test, one_task=False, Train=False) 26 | 27 | # generate dataset with one generator by task 28 | self.model.generate_best_dataset(ind_task - 1, nb_sample_train, one_task = False, Train=True, classe2generate = ind_task) 29 | 30 | #self.model.generate_best_dataset(ind_task - 1, nb_sample_test, Train=False) 31 | 32 | train_loader, test_loader = self.create_next_data(ind_task) 33 | return train_loader, test_loader 34 | -------------------------------------------------------------------------------- /Training/Ewc.py: -------------------------------------------------------------------------------- 1 | from Training.Trainer import Trainer 2 | from torch.nn import functional as F 3 | from utils import variable 4 | import torch.nn as nn 5 | import torch 6 | import random 7 | import numpy as np 8 | from torch.autograd import Variable 9 | from copy import deepcopy 10 | 11 | 12 | class Ewc(Trainer): 13 | def __init__(self, model, args): 14 | super(Ewc, self).__init__(model, args) 15 | self.params = None 16 | self._means = None 17 | self._precision_matrices = None 18 | self.importance = args.lambda_EWC 19 | 20 | def penalty(self, model: nn.Module): 21 | 22 | if self.context == 'Classification': 23 | models = [model] 24 | elif self.context == 'Generation': 25 | if model.model_name in ['CVAE', 'VAE']: 26 | models = [model.G, model.E] 27 | elif model.model_name in ['CGAN', 'GAN', 'WGAN', 'BEGAN']: 28 | models = [model.G] 29 | 30 | loss = 0 31 | for model_ in models: 32 | for n, p in model_.named_parameters(): 33 | _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2 34 | loss += _loss.sum() 35 | return loss 36 | 37 | def _diag_fisher(self, model): 38 | 39 | if self.context == 'Classification': 40 | 41 | precision_matrices = {} 42 | for n, p in deepcopy(self.params).items(): 43 | p.data.zero_() 44 | precision_matrices[n] = variable(p.data) 45 | 46 | model.eval() 47 | 48 | for input in self.old_task: 49 | 50 | model.zero_grad() 51 | input = variable(input).view(-1, 1) 52 | output = model(input).view(1, -1) 53 | label = output.max(1)[1].view(-1) 54 | loss = F.nll_loss(F.log_softmax(output, dim=1), label) 55 | loss.backward() 56 | 57 | for n, p in model.named_parameters(): 58 | precision_matrices[n].data += p.grad.data ** 2 / len(self.old_task) 59 | 60 | precision_matrices = {n: p for n, p in precision_matrices.items()} 61 | return precision_matrices 62 | 63 | if self.context == 'Generation': 64 | 65 | old_batch_size = self.train_loader.batch_size 66 | self.train_loader.batch_size = 1 67 | #model.batch_size = 1 68 | 69 | precision_matrices = {} 70 | for n, p in deepcopy(self.params).items(): 71 | p.data.zero_() 72 | precision_matrices[n] = variable(p.data) 73 | 74 | model.eval() 75 | 76 | self.y_real_ = variable(torch.ones(1, 1)) 77 | self.y_fake_ = variable(torch.zeros(1, 1)) 78 | 79 | if model.model_name in ['CVAE', 'VAE']: 80 | 81 | models = [model.G, model.E] 82 | 83 | model.E.eval() 84 | model.G.eval() 85 | 86 | for iter, (x_, t_) in enumerate(self.train_loader): 87 | 88 | self.model.E_optimizer.zero_grad() 89 | self.model.G_optimizer.zero_grad() 90 | 91 | x_ = variable(x_) 92 | 93 | if model.model_name == 'CVAE': 94 | y_onehot = variable(model.get_one_hot(t_)) 95 | z_, mu, logvar = model.E(x_, y_onehot) 96 | recon_batch = model.G(z_, y_onehot) 97 | else: 98 | z_, mu, logvar = model.E(x_) 99 | recon_batch = model.G(z_) 100 | 101 | loss = model.loss_function(recon_batch, x_, mu, logvar) 102 | loss = torch.log(loss) 103 | loss.backward() 104 | 105 | for model_ in models: 106 | for n, p in model_.named_parameters(): 107 | precision_matrices[n].data += p.grad.data ** 2 / int(len(self.train_loader)) 108 | 109 | elif model.model_name in ["GAN", "CGAN" , "WGAN"]: 110 | 111 | models = [model.G] 112 | 113 | self.model.G.eval() 114 | self.model.D.eval() 115 | 116 | for iter, (x_, t_) in enumerate(self.train_loader): 117 | 118 | self.model.G_optimizer.zero_grad() 119 | self.model.D_optimizer.zero_grad() 120 | 121 | 122 | z_ = variable(torch.rand((1, self.model.z_dim))) 123 | 124 | if model.model_name == 'CGAN': 125 | y_onehot = variable(model.get_one_hot(t_)) 126 | G_ = self.model.G(z_, y_onehot) 127 | D_fake = self.model.D(G_, y_onehot) 128 | BCELoss = nn.BCELoss() 129 | G_loss = BCELoss(D_fake[0], self.y_real_) 130 | G_loss = torch.log(G_loss) 131 | G_loss = torch.mean(G_loss) 132 | 133 | elif model.model_name == 'GAN': 134 | 135 | G_ = model.G(z_) 136 | D_fake = model.D(G_) 137 | G_loss = model.BCELoss(D_fake, self.y_real_) 138 | G_loss = torch.log(G_loss) 139 | 140 | elif model.model_name == 'WGAN': 141 | 142 | z_ = variable(torch.rand((1, model.z_dim, 1, 1))) 143 | 144 | # clipping D 145 | for p in model.D.parameters(): 146 | p.data.clamp_(-model.c, model.c) 147 | 148 | G_ = model.G(z_) 149 | D_fake = model.D(G_) 150 | D_fake = torch.log(D_fake) 151 | G_loss = -torch.mean(D_fake) 152 | 153 | G_loss.backward() 154 | 155 | for model_ in models: 156 | for n, p in model_.named_parameters(): 157 | precision_matrices[n].data += p.grad.data ** 2 / int(len(self.train_loader)) 158 | 159 | 160 | '''if model.model_name == 'BEGAN': 161 | 162 | precision_matrices = {n: p for n, p in precision_matrices.items()} 163 | return precision_matrices''' 164 | 165 | else: 166 | print('Not implemented yet') 167 | return None 168 | 169 | precision_matrices = {n: p for n, p in precision_matrices.items()} 170 | self.train_loader.batch_size = old_batch_size 171 | model.batch_size = old_batch_size 172 | return precision_matrices 173 | 174 | def additional_loss(self, model): 175 | if self.ind_task > 0: 176 | loss = self.importance * self.penalty(model) 177 | else: 178 | loss = None 179 | return loss 180 | 181 | def preparation_4_task(self, model, ind_task): 182 | 183 | if self.context == 'Classification': 184 | 185 | # Here model is only the generator. 186 | 187 | train_loader, test_loader = self.create_next_data(ind_task) 188 | 189 | self.model.train() 190 | 191 | # we only save the importance of weights at the beginning of each new task 192 | old_tasks = [] 193 | for sub_task in range(self.ind_task + 1): 194 | old_tasks = old_tasks + list(self.train_loader[sub_task].get_sample(self.samples_per_task)) 195 | old_tasks = random.sample(old_tasks, k=self.samples_per_task) 196 | self.old_task = old_tasks 197 | 198 | self.params = {n: p for n, p in model.named_parameters() if p.requires_grad} 199 | self._means = {} 200 | self._precision_matrices = self._diag_fisher(model) 201 | 202 | for n, p in deepcopy(self.params).items(): 203 | self._means[n] = variable(p.data) 204 | 205 | return train_loader, test_loader 206 | 207 | if ind_task > 0 and self.context == 'Generation': 208 | 209 | ### Compute Fischer info matrix 210 | self.params = {n: p for n, p in model.G.named_parameters() if p.requires_grad} 211 | 212 | if model.model_name in ['CVAE', 'VAE']: 213 | self.params_E = {n: p for n, p in model.E.named_parameters() if p.requires_grad} 214 | self.params = {**self.params, **self.params_E} 215 | 216 | self._means = {} 217 | precision_matrices_this_task = self._diag_fisher(model) 218 | 219 | # update fisher info with previous tasks fisher info, and this task fisher info (sum) 220 | if ind_task == 1: 221 | self._precision_matrices = precision_matrices_this_task 222 | else: 223 | self._precision_matrices = {n: p + self._precision_matrices[n] for n, p in 224 | precision_matrices_this_task.items()} 225 | 226 | for n, p in deepcopy(self.params).items(): 227 | self._means[n] = variable(p.data) 228 | 229 | ### 230 | nb_sample_train = len(self.train_loader[ind_task]) 231 | nb_sample_test = int(nb_sample_train * 0.2) 232 | 233 | # we generate dataset for later evaluation 234 | nb_sample_train = self.sample_transfer #len(self.train_loader[ind_task]) 235 | nb_sample_test = int(nb_sample_train * 0.2) 236 | self.model.generate_dataset(ind_task - 1, nb_sample_train, one_task=False, Train=True) 237 | #self.model.generate_dataset(ind_task - 1, nb_sample_test, one_task=False, Train=False) 238 | 239 | 240 | return train_loader, test_loader 241 | -------------------------------------------------------------------------------- /Training/Ewc_samples.py: -------------------------------------------------------------------------------- 1 | from Training.Trainer import Trainer 2 | from torch.nn import functional as F 3 | from utils import variable 4 | import torch.nn as nn 5 | import torch 6 | import random 7 | import numpy as np 8 | from torch.autograd import Variable 9 | from copy import deepcopy 10 | 11 | 12 | class Ewc_samples(Trainer): 13 | def __init__(self, model, args): 14 | super(Ewc_samples, self).__init__(model, args) 15 | self.params = None 16 | self._means = None 17 | self._precision_matrices = None 18 | self.importance = args.lambda_EWC 19 | 20 | def penalty(self, model: nn.Module): 21 | 22 | if self.context == 'Classification': 23 | models = [model] 24 | elif self.context == 'Generation': 25 | if model.model_name in ['CVAE', 'VAE']: 26 | models = [model.G, model.E] 27 | elif model.model_name in ['CGAN', 'GAN', 'WGAN', 'BEGAN']: 28 | models = [model.G] 29 | 30 | loss = 0 31 | for model_ in models: 32 | for n, p in model_.named_parameters(): 33 | _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2 34 | loss += _loss.sum() 35 | return loss 36 | 37 | def _diag_fisher(self, model): 38 | 39 | if self.context == 'Classification': 40 | 41 | precision_matrices = {} 42 | for n, p in deepcopy(self.params).items(): 43 | p.data.zero_() 44 | precision_matrices[n] = variable(p.data) 45 | 46 | model.eval() 47 | 48 | for input in self.old_task: 49 | 50 | model.zero_grad() 51 | input = variable(input).view(-1, 1) 52 | output = model(input).view(1, -1) 53 | label = output.max(1)[1].view(-1) 54 | loss = F.nll_loss(F.log_softmax(output, dim=1), label) 55 | loss.backward() 56 | 57 | for n, p in model.named_parameters(): 58 | precision_matrices[n].data += p.grad.data ** 2 / len(self.old_task) 59 | 60 | precision_matrices = {n: p for n, p in precision_matrices.items()} 61 | return precision_matrices 62 | 63 | if self.context == 'Generation': 64 | 65 | old_batch_size = self.train_loader.batch_size 66 | self.train_loader.batch_size = 1 67 | #model.batch_size = 1 68 | 69 | precision_matrices = {} 70 | for n, p in deepcopy(self.params).items(): 71 | p.data.zero_() 72 | precision_matrices[n] = variable(p.data) 73 | 74 | model.eval() 75 | 76 | self.y_real_ = variable(torch.ones(1, 1)) 77 | self.y_fake_ = variable(torch.zeros(1, 1)) 78 | 79 | if model.model_name in ['CVAE', 'VAE']: 80 | 81 | models = [model.G, model.E] 82 | 83 | model.E.eval() 84 | model.G.eval() 85 | 86 | for iter, (x_, t_) in enumerate(self.old_task): 87 | 88 | self.model.E_optimizer.zero_grad() 89 | self.model.G_optimizer.zero_grad() 90 | 91 | x_ = variable(x_).view(1, 784) 92 | 93 | if model.model_name == 'CVAE': 94 | y_onehot = variable(model.get_one_hot(t_)) 95 | z_, mu, logvar = model.E(x_, y_onehot) 96 | recon_batch = model.G(z_, y_onehot) 97 | else: 98 | z_, mu, logvar = model.E(x_) 99 | recon_batch = model.G(z_) 100 | 101 | loss = model.loss_function(recon_batch, x_, mu, logvar) 102 | loss = torch.log(loss) 103 | loss.backward() 104 | 105 | for model_ in models: 106 | for n, p in model_.named_parameters(): 107 | precision_matrices[n].data += p.grad.data ** 2 / int(len(self.train_loader)) 108 | 109 | elif model.model_name in ["GAN", "CGAN" , "WGAN"]: 110 | 111 | models = [model.G] 112 | 113 | self.model.G.eval() 114 | self.model.D.eval() 115 | 116 | for iter, (x_, t_) in enumerate(self.old_task): 117 | 118 | self.model.G_optimizer.zero_grad() 119 | self.model.D_optimizer.zero_grad() 120 | 121 | 122 | z_ = variable(torch.rand((1, self.model.z_dim))) 123 | 124 | if model.model_name == 'CGAN': 125 | y_onehot = variable(model.get_one_hot(t_)) 126 | G_ = self.model.G(z_, y_onehot) 127 | D_fake = self.model.D(G_, y_onehot) 128 | BCELoss = nn.BCELoss() 129 | G_loss = BCELoss(D_fake[0], self.y_real_) 130 | G_loss = torch.log(G_loss) 131 | G_loss = torch.mean(G_loss) 132 | 133 | elif model.model_name == 'GAN': 134 | 135 | G_ = model.G(z_) 136 | D_fake = model.D(G_) 137 | G_loss = model.BCELoss(D_fake, self.y_real_) 138 | G_loss = torch.log(G_loss) 139 | 140 | elif model.model_name == 'WGAN': 141 | 142 | z_ = variable(torch.rand((1, model.z_dim, 1, 1))) 143 | 144 | # clipping D 145 | for p in model.D.parameters(): 146 | p.data.clamp_(-model.c, model.c) 147 | 148 | G_ = model.G(z_) 149 | D_fake = model.D(G_) 150 | D_fake = torch.log(D_fake) 151 | G_loss = -torch.mean(D_fake) 152 | 153 | G_loss.backward() 154 | 155 | for model_ in models: 156 | for n, p in model_.named_parameters(): 157 | precision_matrices[n].data += p.grad.data ** 2 / int(len(self.train_loader)) 158 | 159 | 160 | '''if model.model_name == 'BEGAN': 161 | 162 | precision_matrices = {n: p for n, p in precision_matrices.items()} 163 | return precision_matrices''' 164 | 165 | else: 166 | print('Not implemented yet') 167 | return None 168 | 169 | precision_matrices = {n: p for n, p in precision_matrices.items()} 170 | self.train_loader.batch_size = old_batch_size 171 | model.batch_size = old_batch_size 172 | return precision_matrices 173 | 174 | def additional_loss(self, model): 175 | if self.ind_task > 0: 176 | loss = self.importance * self.penalty(model) 177 | else: 178 | loss = None 179 | return loss 180 | 181 | def preparation_4_task(self, model, ind_task): 182 | 183 | 184 | if self.context == 'Classification': 185 | 186 | # Here model is only the generator. 187 | 188 | train_loader, test_loader = self.create_next_data(ind_task) 189 | 190 | self.model.train() 191 | 192 | # we only save the importance of weights at the beginning of each new task 193 | old_tasks = [] 194 | for sub_task in range(self.ind_task + 1): 195 | old_tasks = old_tasks + list(self.train_loader[sub_task].get_sample(self.samples_per_task)) 196 | old_tasks = random.sample(old_tasks, k=self.samples_per_task) 197 | self.old_task = old_tasks 198 | 199 | self.params = {n: p for n, p in model.named_parameters() if p.requires_grad} 200 | self._means = {} 201 | self._precision_matrices = self._diag_fisher(model) 202 | 203 | for n, p in deepcopy(self.params).items(): 204 | self._means[n] = variable(p.data) 205 | 206 | return train_loader, test_loader 207 | 208 | if ind_task > 0 and self.context == 'Generation': 209 | 210 | # Here model is generator + discriminator/encoder. 211 | 212 | train_loader, test_loader = self.create_next_data(ind_task) 213 | 214 | # we only save the importance of weights at the beginning of each new task 215 | old_tasks = [] 216 | for sub_task in range(self.ind_task + 1): 217 | data = list(self.train_loader[sub_task].get_sample(self.samples_per_task)) 218 | data = [(data[0][i], torch.LongTensor([data[1][i]])) for i in range(len(data[0]))] 219 | old_tasks = old_tasks + data 220 | #old_tasks = random.sample(old_tasks, k=self.samples_per_task) 221 | self.old_task = old_tasks 222 | 223 | self.params = {n: p for n, p in model.G.named_parameters() if p.requires_grad} 224 | if model.model_name in ['CVAE', 'VAE']: 225 | self.params_E = {n: p for n, p in model.E.named_parameters() if p.requires_grad} 226 | self.params = {**self.params, **self.params_E} 227 | self._means = {} 228 | self._precision_matrices = self._diag_fisher(model) 229 | 230 | for n, p in deepcopy(self.params).items(): 231 | self._means[n] = variable(p.data) 232 | 233 | # generate dataset for later evaluation 234 | nb_sample_train = self.sample_transfer #len(self.train_loader[ind_task]) 235 | nb_sample_test = int(nb_sample_train * 0.2) 236 | self.model.generate_dataset(ind_task - 1, nb_sample_train, one_task=False, Train=True) 237 | #self.model.generate_dataset(ind_task - 1, nb_sample_test, one_task=False, Train=False) 238 | 239 | return train_loader, test_loader 240 | 241 | train_loader, test_loader = self.create_next_data(ind_task) 242 | return train_loader, test_loader 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | ''' 254 | 255 | if ind_task > 0 and self.context == 'Generation': 256 | 257 | ### Compute Fischer info matrix 258 | self.params = {n: p for n, p in model.G.named_parameters() if p.requires_grad} 259 | 260 | if model.model_name in ['CVAE', 'VAE']: 261 | self.params_E = {n: p for n, p in model.E.named_parameters() if p.requires_grad} 262 | self.params = {**self.params, **self.params_E} 263 | 264 | self._means = {} 265 | precision_matrices_this_task = self._diag_fisher(model) 266 | 267 | # update fisher info with previous tasks fisher info, and this task fisher info (sum) 268 | if ind_task == 1: 269 | self._precision_matrices = precision_matrices_this_task 270 | else: 271 | self._precision_matrices = {n: p + self._precision_matrices[n] for n, p in 272 | precision_matrices_this_task.items()} 273 | 274 | for n, p in deepcopy(self.params).items(): 275 | self._means[n] = variable(p.data) 276 | 277 | ### 278 | nb_sample_train = len(self.train_loader[ind_task]) 279 | nb_sample_test = int(nb_sample_train * 0.2) 280 | 281 | # we generate dataset for later evaluation 282 | self.model.generate_dataset(ind_task - 1, nb_sample_train, nb_sample_test, one_task=False) 283 | 284 | train_loader, test_loader = self.create_next_data(ind_task) 285 | 286 | return train_loader, test_loader 287 | 288 | train_loader, test_loader = self.create_next_data(ind_task) 289 | return train_loader, test_loader''' 290 | -------------------------------------------------------------------------------- /Training/Generative_Replay.py: -------------------------------------------------------------------------------- 1 | from Training.Trainer import Trainer 2 | 3 | 4 | class Generative_Replay(Trainer): 5 | def __init__(self, model, args): 6 | super(Generative_Replay, self).__init__(model, args) 7 | 8 | 9 | 10 | 11 | def create_next_data(self, ind_task): 12 | 13 | task_te_gen = None 14 | 15 | if ind_task > 0: 16 | 17 | #nb_sample_train = len(self.train_loader[ind_task]) 18 | #nb_sample_test = int(nb_sample_train * 0.2) 19 | 20 | self.train_loader[ind_task] #we set the good index of dataset 21 | #self.test_loader[ind_task] #we set the good index of dataset 22 | 23 | #print("nb_sample_train should be higher than this") 24 | print("numbe of train sample per task is fixed as : " + str(self.sample_transfer)) 25 | 26 | nb_sample_train = self.sample_transfer # approximate size of one task 27 | #nb_sample_test = int(nb_sample_train * 0.2) 28 | 29 | #task_tr_gen = self.model.generate_dataset(ind_task - 1, nb_sample_train, one_task=False, Train=True) 30 | task_tr_gen = self.generate_dataset(ind_task, nb_sample_train, classe2generate=ind_task, 31 | Train=True) 32 | #task_tr_gen = self.model.generate_dataset(ind_task - 1, nb_sample_test, one_task=False, Train=True) 33 | 34 | self.train_loader.concatenate(task_tr_gen) 35 | train_loader = self.train_loader[ind_task] 36 | train_loader.shuffle_task() 37 | 38 | if task_te_gen is not None: 39 | self.test_loader.concatenate(task_te_gen) 40 | test_loader = self.test_loader[ind_task] 41 | test_loader.shuffle_task() 42 | else: 43 | test_loader = None #we don't use test loader for instance but we keep the code for later in case of 44 | 45 | 46 | else: 47 | train_loader = self.train_loader[ind_task] 48 | test_loader = self.test_loader[ind_task] 49 | 50 | return train_loader, test_loader 51 | -------------------------------------------------------------------------------- /Training/README.md: -------------------------------------------------------------------------------- 1 | In this folder are defined the different methods to train all the generator / Incrementaly or not 2 | -------------------------------------------------------------------------------- /Training/Rehearsal.py: -------------------------------------------------------------------------------- 1 | from Training.Trainer import Trainer 2 | from Data.data_loader import DataLoader 3 | 4 | 5 | class Rehearsal(Trainer): 6 | def __init__(self, model, args): 7 | super(Rehearsal, self).__init__(model, args) 8 | self.nb_samples_reharsal = args.nb_samples_reharsal 9 | self.data_memory = None 10 | self.task_samples = None 11 | self.task_labels = None 12 | 13 | def create_next_data(self, ind_task): 14 | # 15 | # save sample before modification of training set 16 | x_tr, y_tr = self.train_loader[ind_task].get_sample(self.nb_samples_reharsal) 17 | if self.gpu_mode: 18 | x_tr, y_tr = x_tr.cpu(), y_tr.cpu() 19 | self.task_samples = x_tr.clone() 20 | self.task_labels = y_tr.clone() 21 | 22 | # create data loder with memory fro; previous task 23 | if ind_task > 0: 24 | 25 | # balanced the number of sample and incorporate it in the memory 26 | 27 | # put the memory inside the training dataset 28 | self.train_loader[ind_task].concatenate(self.data_memory) 29 | self.train_loader[ind_task].shuffle_task() 30 | train_loader = self.train_loader[ind_task] 31 | test_loader = None 32 | 33 | else: 34 | train_loader = self.train_loader[ind_task] 35 | test_loader = None 36 | # test_loader = self.test_loader[ind_task] 37 | 38 | # Add data to memory at the end 39 | c1 = 0 40 | c2 = 1 41 | tasks_tr = [] # reset the list 42 | 43 | # save samples from the actual task in the memory 44 | 45 | tasks_tr.append([(c1, c2), self.task_samples.clone().view(-1, 784), self.task_labels.clone().view(-1)]) 46 | increase_factor = int(self.sample_transfer / self.nb_samples_reharsal) 47 | if ind_task <= 0: 48 | self.data_memory = DataLoader(tasks_tr, self.args).increase_size(increase_factor) 49 | else: 50 | self.data_memory.concatenate(DataLoader(tasks_tr, self.args).increase_size(increase_factor)) 51 | 52 | return train_loader, test_loader 53 | -------------------------------------------------------------------------------- /Training/Trainer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from utils import * 4 | from log_utils import * 5 | from tqdm import tqdm 6 | from Data.data_loader import DataLoader 7 | import time 8 | import numpy as np 9 | from Evaluation.Reviewer import Reviewer 10 | 11 | 12 | class Trainer(object): 13 | def __init__(self,model, args, reviewer=None): 14 | self.args=args 15 | 16 | self.context=args.context 17 | 18 | if self.context=="Generation": 19 | self.reviewer = reviewer 20 | 21 | self.conditional = args.conditional 22 | self.dataset = args.dataset 23 | self.batch_size=args.batch_size 24 | self.gpu_mode=args.gpu_mode 25 | self.device=args.device 26 | self.verbose=args.verbose 27 | 28 | if self.dataset=="mnist" or self.dataset=="fashion": 29 | self.image_size = 28 30 | self.input_size = 1 31 | elif self.dataset=="cifar10": 32 | self.image_size = 32 33 | self.input_size = 3 34 | 35 | self.model=model 36 | 37 | self.sample_dir = args.sample_dir 38 | self.sample_transfer = args.sample_transfer 39 | 40 | self.sample_num = 100 41 | self.save_dir = args.save_dir 42 | self.result_dir = args.result_dir 43 | self.log_dir = args.log_dir 44 | 45 | self.task_type = args.task_type 46 | self.method = args.method 47 | self.batch_size = args.batch_size 48 | self.epochs = args.epochs 49 | self.epochs_gan = args.epoch_G 50 | self.task = None 51 | self.old_task = None 52 | self.num_task = args.num_task 53 | self.num_classes = args.num_classes 54 | self.ind_task=0 55 | self.samples_per_task=args.samples_per_task 56 | 57 | train_loader, test_loader, n_inputs, n_outputs, n_tasks = load_datasets(args) 58 | self.train_loader = DataLoader(train_loader, args) 59 | self.test_loader = DataLoader(test_loader, args) 60 | 61 | 62 | def forward(self, x, ind_task): 63 | return self.model.net.forward(x) 64 | 65 | def additional_loss(self, model): 66 | return None 67 | 68 | def create_next_data(self, ind_task): 69 | return self.train_loader[ind_task], self.test_loader[ind_task] 70 | 71 | def preparation_4_task(self, model, ind_task): 72 | 73 | if ind_task > 0 and self.context == 'Generation': 74 | # We generate as much image as there is in the actual task 75 | 76 | #nb_sample_train = len(self.train_loader[ind_task]) 77 | print("numbe of train sample is fixed as : " + str(self.sample_transfer)) 78 | nb_sample_train = self.sample_transfer # approximate size of one task 79 | nb_sample_test = int(nb_sample_train * 0.2) # not used in fact 80 | 81 | # we generate dataset for later evaluation with image from previous tasks 82 | self.model.generate_dataset(ind_task - 1, nb_sample_train, one_task=False, Train=True) 83 | self.model.generate_dataset(ind_task - 1, nb_sample_test, one_task=False, Train=False) 84 | 85 | train_loader, test_loader = self.create_next_data(ind_task) 86 | return train_loader, test_loader 87 | 88 | def run_generation_tasks(self): 89 | # run each task for a model 90 | 91 | self.model.G.apply(self.model.G.weights_init) 92 | loss, acc, acc_all_tasks = {}, {}, {} 93 | timestamp = time.time() 94 | log_time = [] 95 | for ind_task in range(self.args.num_task): 96 | 97 | print("Task : " + str(ind_task)) 98 | 99 | if 'Ewc' in self.method: 100 | train_loader, test_loader = self.preparation_4_task(self.model, ind_task) 101 | else: 102 | train_loader, test_loader = self.preparation_4_task(self.model.G, ind_task) 103 | self.ind_task=ind_task 104 | 105 | #self.visualize_Samples(train_loader, ind_task) 106 | 107 | path = os.path.join(self.sample_dir, 'sample_' + str(ind_task) + '.png') 108 | 109 | if self.verbose: 110 | print("some sample from the train_laoder") 111 | self.train_loader.visualize_sample(path, self.sample_num, [self.image_size, self.image_size, self.input_size]) 112 | 113 | loss[ind_task] = [] 114 | acc[ind_task] = [] 115 | acc_all_tasks[ind_task] = [] 116 | start_time = time.time() 117 | 118 | 119 | for epoch in range(self.args.epochs): 120 | print("Epoch : "+ str(epoch)) 121 | 122 | loss_epoch = self.model.train_on_task(train_loader, ind_task, epoch, self.additional_loss) 123 | 124 | self.model.visualize_results((epoch + 1), ind_task) 125 | 126 | loss[ind_task].append(loss_epoch) 127 | 128 | # Eval the FID 129 | 130 | ''' 131 | if 'upperbound' in self.task_type: 132 | test_file = self.task_type + '_' + str(self.num_task) + '_test.pt' 133 | else: 134 | test_file = 'upperbound_' + self.task_type + '_' + str(self.num_task) + '_test.pt' 135 | gen_DataLoader = self.model.generate_task(ind_task, nb_sample_train=1000) 136 | true_DataLoader = self.test_loader[ind_task] 137 | FID_epoch = self.reviewer.Frechet_Inception_Distance(gen_DataLoader, true_DataLoader, ind_task) 138 | ''' 139 | 140 | #for previous_task in range(ind_task + 1): 141 | #acc[previous_task].append(self.test_G(previous_task)) 142 | #acc_all_tasks[ind_task].append(self.test_G_all_tasks()) 143 | # Or save generator 144 | self.model.save_G(self.ind_task) 145 | 146 | log_time.append(time.time() - timestamp) 147 | timestamp = time.time() 148 | 149 | np.savetxt(os.path.join(self.log_dir,'task_training_time.txt'), log_time) 150 | 151 | 152 | #nb_sample_train = len(self.train_loader[0]) 153 | nb_sample_train = self.sample_transfer # approximate size of one task 154 | nb_sample_test = int(nb_sample_train * 0.2) 155 | 156 | # generate dataset for all task (indice of last task is num_task-1) and save it 157 | self.model.generate_dataset(self.num_task-1, nb_sample_train, one_task=False, Train=True) 158 | #self.model.generate_dataset(self.num_task-1, nb_sample_test, one_task=False, Train=False) 159 | 160 | if self.method == 'Baseline': # this kind of thing should not exist, inheritance should avoid it 161 | self.model.generate_best_dataset(self.num_task - 1, nb_sample_train, Train=True) 162 | #self.model.generate_best_dataset(self.num_task - 1, nb_sample_test, Train=True) 163 | 164 | 165 | 166 | def run_classification_tasks(self): 167 | accuracy_test = 0 168 | loss, acc, acc_all_tasks = {}, {}, {} 169 | for ind_task in range(self.num_task): 170 | accuracy_task = 0 171 | train_loader, test_loader = self.preparation_4_task(self.model.net, ind_task) 172 | 173 | self.ind_task=ind_task 174 | 175 | if not self.args.task_type == "CUB200": 176 | path = os.path.join(self.sample_dir, 'sample_' + str(ind_task) + '.png') 177 | 178 | if self.verbose: 179 | print("some sample from the train_loader") 180 | self.train_loader.visualize_sample(path, self.sample_num, [self.image_size, self.image_size, self.input_size]) 181 | else: 182 | print("visualisation of CUB200 not implemented") 183 | loss[ind_task] = [] 184 | acc[ind_task] = [] 185 | acc_all_tasks[ind_task] = [] 186 | for epoch in tqdm(range(self.args.epochs)): 187 | loss_epoch, accuracy_epoch = self.model.train_on_task(train_loader, ind_task, epoch, self.additional_loss) 188 | loss[ind_task].append(loss_epoch) 189 | 190 | if accuracy_epoch > accuracy_task: 191 | self.model.save(ind_task) 192 | accuracy_task = accuracy_epoch 193 | 194 | for previous_task in range(ind_task + 1): 195 | loss_test, test_acc, classe_prediction, classe_total, classe_wrong = self.model.eval_on_task( 196 | self.test_loader[previous_task], 0) 197 | 198 | #acc[previous_task].append(self.test(previous_task)) 199 | acc[previous_task].append(test_acc) 200 | 201 | accuracy_test_epoch=self.test_all_tasks() 202 | acc_all_tasks[ind_task].append(accuracy_test_epoch) 203 | 204 | if accuracy_test_epoch > accuracy_test: 205 | #if True: 206 | self.model.save(ind_task, Best=True) 207 | accuracy_test = accuracy_test_epoch 208 | 209 | 210 | loss_plot(loss, self.args) 211 | accuracy_plot(acc, self.args) 212 | accuracy_all_plot(acc_all_tasks, self.args) 213 | 214 | 215 | def test_all_tasks(self): 216 | self.model.net.eval() 217 | 218 | mean_task = 0 219 | if self.task_type == 'upperbound': 220 | loss, mean_task, classe_prediction, classe_total, classe_wrong = self.model.eval_on_task( 221 | self.test_loader[self.num_task - 1], 0) 222 | else: 223 | for ind_task in range(self.num_task): 224 | loss, acc_task, classe_prediction, classe_total, classe_wrong = self.model.eval_on_task( 225 | self.test_loader[ind_task], 0) 226 | 227 | mean_task += acc_task 228 | mean_task = mean_task/self.num_task 229 | print("Mean overall performance : " + str(mean_task.item())) 230 | return mean_task 231 | 232 | 233 | def regenerate_datasets_for_eval(self): 234 | 235 | nb_sample_train = self.sample_transfer #len(self.train_loader[0]) 236 | #nb_sample_test = int(nb_sample_train * 0.2) 237 | 238 | for i in range(self.args.num_task): 239 | self.model.load_G(ind_task=i) 240 | self.generate_dataset(i, nb_sample_train, classe2generate=i+1, Train=True) 241 | 242 | return 243 | 244 | def generate_dataset(self, ind_task,sample_per_classes, classe2generate, Train=True): 245 | return self.model.generate_dataset(ind_task, sample_per_classes, one_task=False, Train=Train, classe2generate=classe2generate) 246 | 247 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: py36 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _license=1.1=py36_1 7 | - alabaster=0.7.10=py36_0 8 | - anaconda-client=1.6.3=py36_0 9 | - anaconda=custom=py36hbbc8b67_0 10 | - anaconda-navigator=1.6.2=py36_0 11 | - anaconda-project=0.6.0=py36_0 12 | - asn1crypto=0.22.0=py36_0 13 | - astroid=1.4.9=py36_0 14 | - astropy=1.3.2=np112py36_0 15 | - babel=2.4.0=py36_0 16 | - backports=1.0=py36_0 17 | - beautifulsoup4=4.6.0=py36_0 18 | - bitarray=0.8.1=py36_0 19 | - blas=1.0=mkl 20 | - blaze=0.10.1=py36_0 21 | - bleach=1.5.0=py36_0 22 | - bokeh=0.12.5=py36_1 23 | - boto=2.46.1=py36_0 24 | - bottleneck=1.2.1=np112py36_0 25 | - cairo=1.14.8=0 26 | - cffi=1.10.0=py36_0 27 | - chardet=3.0.3=py36_0 28 | - click=6.7=py36_0 29 | - cloudpickle=0.2.2=py36_0 30 | - clyent=1.2.2=py36_0 31 | - colorama=0.3.9=py36_0 32 | - contextlib2=0.5.5=py36_0 33 | - cryptography=1.8.1=py36_0 34 | - cudatoolkit=8.0=3 35 | - cudnn=6.0.21=cuda8.0_0 36 | - curl=7.52.1=0 37 | - cycler=0.10.0=py36_0 38 | - cython=0.25.2=py36_0 39 | - cytoolz=0.8.2=py36_0 40 | - dask=0.14.3=py36_1 41 | - datashape=0.5.4=py36_0 42 | - dbus=1.10.10=0 43 | - decorator=4.0.11=py36_0 44 | - distributed=1.16.3=py36_0 45 | - docutils=0.13.1=py36_0 46 | - entrypoints=0.2.2=py36_1 47 | - et_xmlfile=1.0.1=py36_0 48 | - expat=2.1.0=0 49 | - fastcache=1.0.2=py36_1 50 | - flask=0.12.2=py36_0 51 | - flask-cors=3.0.2=py36_0 52 | - fontconfig=2.12.1=3 53 | - freetype=2.5.5=2 54 | - get_terminal_size=1.0.0=py36_0 55 | - gevent=1.2.1=py36_0 56 | - glib=2.50.2=1 57 | - greenlet=0.4.12=py36_0 58 | - gst-plugins-base=1.8.0=0 59 | - gstreamer=1.8.0=0 60 | - h5py=2.7.0=np112py36_0 61 | - harfbuzz=0.9.39=2 62 | - hdf5=1.8.17=1 63 | - heapdict=1.0.0=py36_1 64 | - html5lib=0.999=py36_0 65 | - icu=54.1=0 66 | - idna=2.5=py36_0 67 | - imagesize=0.7.1=py36_0 68 | - intel-openmp=2019.0=118 69 | - ipykernel=4.6.1=py36_0 70 | - ipython=5.3.0=py36_0 71 | - ipython_genutils=0.2.0=py36_0 72 | - ipywidgets=6.0.0=py36_0 73 | - isort=4.2.5=py36_0 74 | - itsdangerous=0.24=py36_0 75 | - jbig=2.1=0 76 | - jdcal=1.3=py36_0 77 | - jedi=0.10.2=py36_2 78 | - jinja2=2.9.6=py36_0 79 | - jpeg=9b=0 80 | - jsonschema=2.6.0=py36_0 81 | - jupyter=1.0.0=py36_3 82 | - jupyter_client=5.0.1=py36_0 83 | - jupyter_console=5.1.0=py36_0 84 | - jupyter_core=4.3.0=py36_0 85 | - lazy-object-proxy=1.2.2=py36_0 86 | - libffi=3.2.1=1 87 | - libgcc=4.8.5=2 88 | - libgcc-ng=8.2.0=hdf63c60_1 89 | - libgfortran=3.0.0=1 90 | - libgfortran-ng=7.3.0=hdf63c60_0 91 | - libiconv=1.14=0 92 | - libpng=1.6.27=0 93 | - libsodium=1.0.10=0 94 | - libstdcxx-ng=8.2.0=hdf63c60_1 95 | - libtiff=4.0.6=3 96 | - libtool=2.4.2=0 97 | - libxcb=1.12=1 98 | - libxml2=2.9.4=0 99 | - libxslt=1.1.29=0 100 | - llvmlite=0.18.0=py36_0 101 | - locket=0.2.0=py36_1 102 | - lxml=3.7.3=py36_0 103 | - markupsafe=0.23=py36_2 104 | - matplotlib=2.0.2=np112py36_0 105 | - mistune=0.7.4=py36_0 106 | - mkl=2019.0=118 107 | - mkl-service=1.1.2=py36_3 108 | - mpmath=0.19=py36_1 109 | - msgpack-python=0.4.8=py36_0 110 | - multipledispatch=0.4.9=py36_0 111 | - navigator-updater=0.1.0=py36_0 112 | - nbconvert=5.1.1=py36_0 113 | - nbformat=4.3.0=py36_0 114 | - nccl=1.3.4=cuda8.0_1 115 | - networkx=1.11=py36_0 116 | - ninja=1.8.2=py36h6bb024c_1 117 | - nltk=3.2.3=py36_0 118 | - nose=1.3.7=py36_1 119 | - notebook=5.0.0=py36_0 120 | - numba=0.33.0=np112py36_0 121 | - numexpr=2.6.8=py36hd89afb7_0 122 | - numpy=1.12.1=py36he24570b_1 123 | - numpydoc=0.6.0=py36_0 124 | - odo=0.5.0=py36_1 125 | - olefile=0.44=py36_0 126 | - openpyxl=2.4.7=py36_0 127 | - openssl=1.0.2l=0 128 | - packaging=16.8=py36_0 129 | - pandas=0.20.1=np112py36_0 130 | - pandocfilters=1.4.1=py36_0 131 | - pango=1.40.3=1 132 | - partd=0.3.8=py36_0 133 | - path.py=10.3.1=py36_0 134 | - pathlib2=2.2.1=py36_0 135 | - patsy=0.4.1=py36_0 136 | - pcre=8.39=1 137 | - pep8=1.7.0=py36_0 138 | - pexpect=4.2.1=py36_0 139 | - pickleshare=0.7.4=py36_0 140 | - pillow=4.1.1=py36_0 141 | - pip=9.0.1=py36_1 142 | - pixman=0.34.0=0 143 | - ply=3.10=py36_0 144 | - prompt_toolkit=1.0.14=py36_0 145 | - psutil=5.2.2=py36_0 146 | - ptyprocess=0.5.1=py36_0 147 | - py=1.4.33=py36_0 148 | - pycosat=0.6.2=py36_0 149 | - pycparser=2.17=py36_0 150 | - pycrypto=2.6.1=py36_6 151 | - pycurl=7.43.0=py36_2 152 | - pyflakes=1.5.0=py36_0 153 | - pygments=2.2.0=py36_0 154 | - pylint=1.6.4=py36_1 155 | - pyodbc=4.0.16=py36_0 156 | - pyopenssl=17.0.0=py36_0 157 | - pyparsing=2.1.4=py36_0 158 | - pyqt=5.6.0=py36_2 159 | - pytables=3.3.0=np112py36_0 160 | - pytest=3.0.7=py36_0 161 | - python=3.6.1=2 162 | - python-dateutil=2.6.0=py36_0 163 | - pytz=2017.2=py36_0 164 | - pywavelets=0.5.2=np112py36_0 165 | - pyyaml=3.12=py36_0 166 | - pyzmq=16.0.2=py36_0 167 | - qt=5.6.2=4 168 | - qtawesome=0.4.4=py36_0 169 | - qtconsole=4.3.0=py36_0 170 | - qtpy=1.2.1=py36_0 171 | - readline=6.2=2 172 | - requests=2.14.2=py36_0 173 | - rope=0.9.4=py36_1 174 | - ruamel_yaml=0.11.14=py36_1 175 | - scikit-image=0.13.0=np112py36_0 176 | - scikit-learn=0.19.1=py36hedc7406_0 177 | - scipy=1.1.0=py36hd20e5f9_0 178 | - seaborn=0.7.1=py36_0 179 | - setuptools=27.2.0=py36_0 180 | - simplegeneric=0.8.1=py36_1 181 | - singledispatch=3.4.0.3=py36_0 182 | - sip=4.18=py36_0 183 | - six=1.10.0=py36_0 184 | - snowballstemmer=1.2.1=py36_0 185 | - sortedcollections=0.5.3=py36_0 186 | - sortedcontainers=1.5.7=py36_0 187 | - sphinx=1.5.6=py36_0 188 | - spyder=3.1.4=py36_0 189 | - sqlalchemy=1.1.9=py36_0 190 | - sqlite=3.13.0=0 191 | - statsmodels=0.8.0=np112py36_0 192 | - sympy=1.0=py36_0 193 | - tblib=1.3.2=py36_0 194 | - terminado=0.6=py36_0 195 | - testpath=0.3=py36_0 196 | - tk=8.5.18=0 197 | - toolz=0.8.2=py36_0 198 | - tornado=4.5.1=py36_0 199 | - traitlets=4.3.2=py36_0 200 | - unicodecsv=0.14.1=py36_0 201 | - unixodbc=2.3.4=0 202 | - wcwidth=0.1.7=py36_0 203 | - werkzeug=0.12.2=py36_0 204 | - wheel=0.29.0=py36_0 205 | - widgetsnbextension=2.0.0=py36_0 206 | - wrapt=1.10.10=py36_0 207 | - xlrd=1.0.0=py36_0 208 | - xlsxwriter=0.9.6=py36_0 209 | - xlwt=1.2.0=py36_0 210 | - xz=5.2.2=1 211 | - yaml=0.1.6=0 212 | - zeromq=4.1.5=0 213 | - zict=0.1.2=py36_0 214 | - zlib=1.2.8=3 215 | - cuda80=1.0=h205658b_0 216 | - pytorch=0.4.1=py36_cuda8.0.61_cudnn7.1.2_1 217 | - torchvision=0.2.1=py36_1 218 | prefix: /home/timothee/anaconda2/envs/py36 219 | 220 | -------------------------------------------------------------------------------- /log_utils.py: -------------------------------------------------------------------------------- 1 | import os, gzip, torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import scipy.misc 5 | import imageio 6 | import matplotlib as mpl 7 | 8 | mpl.use('Agg') 9 | import matplotlib.pyplot as plt 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | import datetime 13 | 14 | 15 | # first try for the log function it will be necessary to update it 16 | def log_test_done(args, state='end'): 17 | f1 = open('test_done.txt', 'a') 18 | if args.context == "Generation": 19 | if args.train_G and state == 'Intermediate': 20 | f1.write('TrainG-{}-{}-{}-{}-{}-{}-{}\n'.format(args.seed, args.dataset, args.gan_type, args.method, 21 | args.context, str(args.upperbound), 22 | datetime.datetime.now())) 23 | elif args.eval or (args.train_G and state == 'End'): 24 | if args.FID: 25 | f1.write('FID-{}-{}-{}-{}-{}-{}\n'.format(args.seed, args.dataset, args.gan_type, args.method, 26 | str(args.upperbound), datetime.datetime.now())) 27 | if args.Fitting_capacity: 28 | f1.write('Fitting_Capacity-{}-{}-{}-{}-{}-{}\n'.format(args.seed, args.dataset, args.gan_type, args.method, 29 | str(args.upperbound), datetime.datetime.now())) 30 | else: 31 | f1.write('Log undefined -{}-{}-{}-{}-{}-{}\n'.format(args.seed, args.dataset, args.gan_type, args.method, 32 | str(args.upperbound), datetime.datetime.now())) 33 | elif args.context == "Classification": 34 | if args.eval: 35 | f1.write( 36 | 'Classification-{}-{}-{}-{}-{}\n'.format(args.seed, args.dataset, args.method, str(args.upperbound), 37 | datetime.datetime.now())) 38 | else: 39 | f1.write( 40 | 'Log undefined2 -{}-{}-{}-{}-{}\n'.format(args.seed, args.dataset, args.method, str(args.upperbound), 41 | datetime.datetime.now())) 42 | f1.close() 43 | 44 | 45 | def img_stretch(img): 46 | img = img.astype(float) 47 | img -= np.min(img) 48 | img /= np.max(img) + 1e-12 49 | return img 50 | 51 | 52 | def make_samples_batche(prediction, batch_size, filename_dest): 53 | plt.figure() 54 | batch_size_sqrt = int(np.sqrt(batch_size)) 55 | input_channel = prediction[0].shape[0] 56 | input_dim = prediction[0].shape[1] 57 | prediction = np.clip(prediction, 0, 1) 58 | pred = np.rollaxis(prediction.reshape((batch_size_sqrt, batch_size_sqrt, input_channel, input_dim, input_dim)), 2, 59 | 5) 60 | pred = pred.swapaxes(2, 1) 61 | pred = pred.reshape((batch_size_sqrt * input_dim, batch_size_sqrt * input_dim, input_channel)) 62 | fig, ax = plt.subplots(figsize=(batch_size_sqrt, batch_size_sqrt)) 63 | ax.axis('off') 64 | ax.imshow(img_stretch(pred), interpolation='nearest') 65 | ax.grid() 66 | ax.set_xticks([]) 67 | ax.set_yticks([]) 68 | fig.savefig(filename_dest, bbox_inches='tight', pad_inches=0) 69 | plt.close(fig) 70 | plt.close() 71 | 72 | 73 | def save_images(images, size, image_path): 74 | return imsave(images, size, image_path) 75 | 76 | 77 | def imsave(images, size, path): 78 | image = np.squeeze(merge(images, size)) 79 | return scipy.misc.imsave(path, image) 80 | 81 | 82 | def merge(images, size): 83 | h, w = images.shape[1], images.shape[2] 84 | if (images.shape[3] in (3, 4)): 85 | c = images.shape[3] 86 | img = np.zeros((h * size[0], w * size[1], c)) 87 | for idx, image in enumerate(images): 88 | i = idx % size[1] 89 | j = idx // size[1] 90 | img[j * h:j * h + h, i * w:i * w + w, :] = image 91 | return img 92 | elif images.shape[3] == 1: 93 | img = np.zeros((h * size[0], w * size[1])) 94 | for idx, image in enumerate(images): 95 | i = idx % size[1] 96 | j = idx // size[1] 97 | img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0] / 255.0 98 | return img 99 | else: 100 | raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4') 101 | 102 | 103 | def generate_animation(path, num): 104 | images = [] 105 | for e in range(num): 106 | img_name = path + '_epoch%03d' % (e + 1) + '.png' 107 | images.append(imageio.imread(img_name)) 108 | imageio.mimsave(path + '_generate_animation.gif', images, fps=5) 109 | 110 | 111 | def loss_G_plot(hist, path='', model_name=''): 112 | x = range(len(hist['D_loss'])) 113 | 114 | y1 = hist['D_loss'] 115 | y2 = hist['G_loss'] 116 | 117 | plt.plot(x, y1, label='D_loss') 118 | plt.plot(x, y2, label='G_loss') 119 | 120 | plt.xlabel('Iter') 121 | plt.ylabel('Loss') 122 | 123 | plt.legend(loc=4) 124 | plt.grid(True) 125 | plt.tight_layout() 126 | 127 | path = os.path.join(path, model_name + '_loss.png') 128 | 129 | plt.savefig(path) 130 | 131 | plt.close() 132 | 133 | 134 | def loss_plot(x, args): 135 | save_txt = np.zeros((len(x.items()), len(x[0]))) 136 | for t, v in x.items(): 137 | save_txt[t] = np.array(v) 138 | plt.plot(list(range(t * args.epochs, (t + 1) * args.epochs)), v) 139 | # plt.plot(list(range(t, (t + 1))), v) 140 | plt.savefig(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_loss_figure.png")) 141 | plt.clf() 142 | np.savetxt(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_loss.txt"), save_txt) 143 | 144 | 145 | def accuracy_all_plot(x, args): 146 | save_txt = np.zeros((len(x.items()), len(x[0]))) 147 | 148 | for t, v in x.items(): 149 | save_txt[t, :len(v)] = np.array(v) 150 | 151 | plt.plot(range(len(x.items()) * len(x[0])), save_txt.reshape(len(x.items()) * len(x[0]))) 152 | plt.savefig(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_overall_accuracy_figure.png")) 153 | plt.clf() 154 | np.savetxt(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_overall_accuracy.txt"), save_txt) 155 | 156 | max = save_txt.max() 157 | 158 | np.savetxt(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_best_overall.txt"), [max]) 159 | 160 | def accuracy_plot(x, args): 161 | save_txt = np.ones((len(x.items()), len(x[0]))) * -1 162 | for t, v in x.items(): 163 | save_txt[t, :len(v)] = np.array(v) 164 | plt.plot(list(range(t * args.epochs, args.num_task * args.epochs)), v) 165 | plt.ylim(0, 1) 166 | plt.savefig(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_accuracy_figure.png")) 167 | plt.clf() 168 | np.savetxt(os.path.join(args.log_dir, args.task_type + '_' + args.method + "_accuracy.txt"), save_txt) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from Classifiers.Fashion_Classifier import Fashion_Classifier 8 | from Classifiers.Mnist_Classifier import Mnist_Classifier 9 | from Classifiers.Cifar_Classifier import Cifar_Classifier 10 | from Data.data_loader import DataLoader 11 | from Evaluation.Reviewer import Reviewer 12 | from Generative_Models.BEGAN import BEGAN 13 | from Generative_Models.CGAN import CGAN 14 | from Generative_Models.CVAE import CVAE 15 | from Generative_Models.GAN import GAN 16 | from Generative_Models.VAE import VAE 17 | from Generative_Models.WGAN import WGAN 18 | from Generative_Models.WGAN_GP import WGAN_GP 19 | from Training.Baseline import Baseline 20 | from Training.Ewc import Ewc 21 | from Training.Ewc_samples import Ewc_samples 22 | from Training.Generative_Replay import Generative_Replay 23 | from Training.Rehearsal import Rehearsal 24 | from log_utils import log_test_done 25 | from utils import check_args 26 | 27 | from Evaluation.Eval_Classifier import Reviewer_C 28 | 29 | """parsing and configuration""" 30 | 31 | 32 | def parse_args(): 33 | desc = "Pytorch implementation of GAN collections" 34 | parser = argparse.ArgumentParser(description=desc) 35 | 36 | parser.add_argument('--gan_type', type=str, default='CVAE', 37 | choices=['GAN', 'Classifier', 'CGAN', 'BEGAN', 'WGAN', 38 | 'WGAN_GP', 'VAE', "CVAE", "WGAN_GP"], 39 | help='The type of GAN') # , required=True) 40 | parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion', 'cifar10'], 41 | help='The name of dataset') 42 | parser.add_argument('--conditional', type=bool, default=False) 43 | parser.add_argument('--upperbound', type=bool, default=False, 44 | help='This variable will be set to true automatically if task_type contains_upperbound') 45 | parser.add_argument('--method', type=str, default='Baseline', choices=['Baseline', 'Ewc', 'Ewc_samples', 46 | 'Generative_Replay', 'Rehearsal']) 47 | 48 | parser.add_argument('--context', type=str, default='Generation', 49 | choices=['Classification', 'Generation', 'Not_Incremental']) 50 | 51 | parser.add_argument('--dir', type=str, default='./Archives/', help='Working directory') 52 | parser.add_argument('--save_dir', type=str, default='models', help='Directory name to save the model') 53 | parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save results') 54 | parser.add_argument('--sample_dir', type=str, default='Samples', help='Directory name to save the generated images') 55 | parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs') 56 | parser.add_argument('--data_dir', type=str, default='Data', help='Directory name for data') 57 | parser.add_argument('--gen_dir', type=str, default='.', help='Directory name for data') 58 | 59 | parser.add_argument('--epochs', type=int, default=1, help='The number of epochs to run') 60 | parser.add_argument('--epoch_G', type=int, default=1, help='The number of epochs to run') 61 | parser.add_argument('--epoch_Review', type=int, default=50, help='The number of epochs to run') 62 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 63 | parser.add_argument('--size_epoch', type=int, default=1000) 64 | parser.add_argument('--gpu_mode', type=bool, default=True) 65 | parser.add_argument('--device', type=int, default=0) 66 | parser.add_argument('--verbose', type=bool, default=False) 67 | 68 | parser.add_argument('--lrG', type=float, default=0.0002) 69 | parser.add_argument('--lrD', type=float, default=0.0002) 70 | parser.add_argument('--lrC', type=float, default=0.01) 71 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)') 72 | parser.add_argument('--beta1', type=float, default=0.5) 73 | parser.add_argument('--beta2', type=float, default=0.999) 74 | 75 | parser.add_argument('--seed', type=int, default=1664) 76 | parser.add_argument('--eval', type=bool, default=True) 77 | parser.add_argument('--train_G', type=bool, default=False) 78 | parser.add_argument('--eval_C', type=bool, default=False) 79 | 80 | ############### UNUSED FLAGS ########################## 81 | parser.add_argument('--trainEval', type=bool, default=False) 82 | parser.add_argument('--knn', type=bool, default=False) 83 | parser.add_argument('--IS', type=bool, default=False) 84 | parser.add_argument('--FID', type=bool, default=False) 85 | parser.add_argument('--Fitting_capacity', type=bool, default=False) 86 | ####################################################### 87 | 88 | parser.add_argument('--num_task', type=int, default=10) 89 | parser.add_argument('--num_classes', type=int, default=10) 90 | parser.add_argument('--sample_transfer', type=int, default=5000) 91 | parser.add_argument('--task_type', type=str, default="disjoint", 92 | choices=['disjoint', 'permutations', 'upperbound_disjoint']) 93 | parser.add_argument('--samples_per_task', type=int, default=200) 94 | parser.add_argument('--lambda_EWC', type=int, default=5) 95 | parser.add_argument('--nb_samples_reharsal', type=int, default=10) 96 | parser.add_argument('--regenerate', type=bool, default=False) 97 | 98 | return check_args(parser.parse_args()) 99 | 100 | 101 | """main""" 102 | 103 | 104 | def main(): 105 | # parse arguments 106 | args = parse_args() 107 | 108 | if args is None: 109 | exit() 110 | 111 | seed = args.seed 112 | torch.manual_seed(seed) 113 | np.random.seed(seed) 114 | args.gpu_mode = torch.cuda.is_available() 115 | 116 | if args.gpu_mode: 117 | torch.backends.cudnn.deterministic = True 118 | torch.cuda.manual_seed_all(args.seed) 119 | 120 | if args.context == 'Generation': 121 | print("Generation : Use of model {} with dataset {}, seed={}".format(args.gan_type, args.dataset, args.seed)) 122 | elif args.context == 'Classification': 123 | print("Classification : Use of method {} with dataset {}, seed={}".format(args.method, args.dataset, args.seed)) 124 | 125 | if args.context == 'Generation': 126 | 127 | # declare instance for GAN 128 | if args.gan_type == 'GAN': 129 | model = GAN(args) 130 | elif args.gan_type == 'CGAN': 131 | model = CGAN(args) 132 | elif args.gan_type == 'VAE': 133 | model = VAE(args) 134 | elif args.gan_type == 'CVAE': 135 | model = CVAE(args) 136 | elif args.gan_type == 'WGAN': 137 | model = WGAN(args) 138 | elif args.gan_type == 'WGAN_GP': 139 | model = WGAN_GP(args) 140 | else: 141 | raise Exception("[!] There is no option for " + args.gan_type) 142 | elif args.context == 'Classification': 143 | if args.dataset == 'mnist': 144 | model = Mnist_Classifier(args) 145 | elif args.dataset == 'fashion': 146 | model = Fashion_Classifier(args) 147 | elif args.dataset == 'cifar10': 148 | model = Cifar_Classifier(args) 149 | else: 150 | print('Not implemented') 151 | 152 | reviewer = Reviewer(args) 153 | 154 | if args.method == 'Baseline': 155 | method = Baseline(model, args, reviewer) 156 | elif 'Ewc' in args.method: 157 | method = Ewc(model, args) 158 | elif 'Ewc_samples' in args.method: 159 | method = Ewc_samples(model, args) 160 | elif args.method == 'Generative_Replay': 161 | method = Generative_Replay(model, args) 162 | elif args.method == 'Rehearsal': 163 | method = Rehearsal(model, args) 164 | else: 165 | print('Method not implemented') 166 | 167 | if args.context == 'Classification': 168 | 169 | if args.eval_C: 170 | reviewer_C = Reviewer_C(args) 171 | list_values = [10, 50, 100, 200, 500, 1000, 5000, 10000] 172 | 173 | reviewer_C.review_all_tasks(args, list_values) 174 | else: 175 | method.run_classification_tasks() 176 | elif args.context == 'Generation': 177 | 178 | if args.train_G: 179 | method.run_generation_tasks() 180 | log_test_done(args, 'Intermediate') 181 | 182 | if args.regenerate: 183 | method.regenerate_datasets_for_eval() 184 | 185 | if args.Fitting_capacity and not args.train_G: 186 | 187 | reviewer = Reviewer(args) 188 | # In case the training training and evaluation are done separately 189 | if args.gpu_mode: 190 | torch.backends.cudnn.deterministic = True 191 | torch.cuda.manual_seed_all(args.seed) 192 | 193 | reviewer.review_all_tasks(args) 194 | 195 | if args.method == "Baseline" and not args.upperbound: 196 | # Baseline produce both one lower bound and one upperbound 197 | # it is not the same upperbound a the one trained for upperbound_disjoint 198 | reviewer.review_all_tasks(args, Best=True) 199 | 200 | if args.FID and not args.train_G: 201 | 202 | reviewer = Reviewer(args) 203 | # In case the training training and evaluation are done separately 204 | if args.gpu_mode: 205 | torch.backends.cudnn.deterministic = True 206 | torch.cuda.manual_seed_all(args.seed) 207 | 208 | reviewer.compute_all_tasks_FID(args) 209 | 210 | if args.method == "Baseline" and not args.upperbound: 211 | # Baseline produce both one lower bound and one upperbound 212 | # it is not the same upperbound a the one trained for upperbound_disjoint 213 | reviewer.compute_all_tasks_FID(args, Best=True) 214 | 215 | if args.trainEval: 216 | reviewer = Reviewer(args) 217 | reviewer.review_all_trainEval(args) 218 | 219 | if args.method == "Baseline" and not args.upperbound: 220 | # Baseline produce both one lower bound and one upperbound 221 | # it is not the same upperbound a the one trained for upperbound_disjoint 222 | reviewer.review_all_trainEval(args, Best=True) 223 | 224 | 225 | else: 226 | print('Not Implemented') 227 | 228 | log_test_done(args) 229 | 230 | 231 | if __name__ == '__main__': 232 | main() 233 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, gzip, torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import scipy.misc 5 | import imageio 6 | import matplotlib as mpl 7 | 8 | mpl.use('Agg') 9 | import matplotlib.pyplot as plt 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | 13 | """checking arguments""" 14 | 15 | 16 | def check_args(args): 17 | if "Ewc" in args.method: 18 | args.method = args.method + '_' + str(args.lambda_EWC) 19 | 20 | args.save_dir = os.path.join(args.dir, args.save_dir) 21 | # --save_dir 22 | if not os.path.exists(args.save_dir): 23 | os.makedirs(args.save_dir) 24 | 25 | args.result_dir = os.path.join(args.dir, args.result_dir) 26 | # --result_dir 27 | if not os.path.exists(args.result_dir): 28 | os.makedirs(args.result_dir) 29 | 30 | args.log_dir = os.path.join(args.dir, args.log_dir) 31 | # --result_dir 32 | if not os.path.exists(args.log_dir): 33 | os.makedirs(args.log_dir) 34 | 35 | args.sample_dir = os.path.join(args.dir, args.sample_dir) 36 | # --sample_dir 37 | if not os.path.exists(args.sample_dir): 38 | os.makedirs(args.sample_dir) 39 | 40 | args.data_dir = os.path.join(args.dir, args.data_dir) 41 | # --sample_dir 42 | if not os.path.exists(args.data_dir): 43 | os.makedirs(args.data_dir) 44 | 45 | args.gen_dir = os.path.join(args.data_dir, 'Generated') 46 | # --sample_dir 47 | if not os.path.exists(args.gen_dir): 48 | os.makedirs(args.gen_dir) 49 | 50 | # --epoch 51 | try: 52 | assert args.epoch >= 1 53 | except: 54 | print('number of epochs must be larger than or equal to one') 55 | 56 | # --batch_size 57 | try: 58 | assert args.batch_size >= 1 59 | except: 60 | print('batch size must be larger than or equal to one') 61 | 62 | if 'upperbound' in args.task_type: 63 | args.upperbound = True 64 | elif (not 'upperbound' in args.task_type) and args.upperbound: 65 | args.task_type = 'upperbound_' + args.task_type 66 | 67 | args.data_file = args.task_type + '_' + str(args.num_task) + '.pt' 68 | 69 | # 70 | if args.gan_type == "VAE" and args.conditional: 71 | args.gan_type = "CVAE" 72 | if args.gan_type == "GAN" and args.conditional: 73 | args.gan_type = "CGAN" 74 | 75 | if args.context == 'Generation': 76 | args.result_dir = os.path.join(args.result_dir, args.context, args.task_type, args.dataset, args.method, 77 | args.gan_type, 78 | 'Num_tasks_' + str(args.num_task), 79 | 'seed_' + str(args.seed)) 80 | args.save_dir = os.path.join(args.save_dir, args.context, args.task_type, args.dataset, args.method, 81 | args.gan_type, 82 | 'Num_tasks_' + str(args.num_task), 'seed_' + str(args.seed)) 83 | args.log_dir = os.path.join(args.log_dir, args.context, args.task_type, args.dataset, args.method, 84 | args.gan_type, 85 | 'Num_tasks_' + str(args.num_task), 'seed_' + str(args.seed)) 86 | args.sample_dir = os.path.join(args.sample_dir, args.context, args.task_type, args.dataset, args.gan_type, 87 | args.method, 88 | 'Num_tasks_' + str(args.num_task), 89 | 'seed_' + str(args.seed)) 90 | args.gen_dir = os.path.join(args.gen_dir, args.dataset, args.gan_type, args.task_type, args.method, 91 | 'Num_tasks_' + str(args.num_task), 92 | 'seed_' + str(args.seed)) 93 | 94 | elif args.context == 'Classification': 95 | args.result_dir = os.path.join(args.result_dir, args.context, args.dataset, args.method, 96 | 'Num_tasks_' + str(args.num_task), 97 | 'seed_' + str(args.seed)) 98 | args.save_dir = os.path.join(args.save_dir, args.context, args.dataset, args.method, 99 | 'Num_tasks_' + str(args.num_task), 'seed_' + str(args.seed)) 100 | args.log_dir = os.path.join(args.log_dir, args.context, args.dataset, args.method, 101 | 'Num_tasks_' + str(args.num_task), 'seed_' + str(args.seed)) 102 | args.sample_dir = os.path.join(args.sample_dir, args.context, args.dataset, args.task_type, 103 | 'Num_tasks_' + str(args.num_task), 104 | 'seed_' + str(args.seed)) 105 | 106 | if not os.path.exists(args.result_dir): 107 | os.makedirs(args.result_dir) 108 | if not os.path.exists(args.save_dir): 109 | os.makedirs(args.save_dir) 110 | if not os.path.exists(args.log_dir): 111 | os.makedirs(args.log_dir) 112 | if not os.path.exists(args.sample_dir): 113 | os.makedirs(args.sample_dir) 114 | if not os.path.exists(args.gen_dir): 115 | os.makedirs(args.gen_dir) 116 | 117 | if args.gan_type == "CVAE" or args.gan_type == "CGAN": 118 | args.conditional = True 119 | 120 | print("Model : ", args.gan_type) 121 | print("Dataset : ", args.dataset) 122 | print("Method : ", args.method) 123 | print("Seed : ", str(args.seed)) 124 | print("Context : ", args.context) 125 | 126 | if args.FID: 127 | print("Doing : FID") 128 | if args.train_G: 129 | print("Doing : Train_G") 130 | if args.Fitting_capacity: 131 | print("Doing : Fitting_capacity") 132 | 133 | return args 134 | 135 | 136 | def print_network(net): 137 | num_params = 0 138 | for param in net.parameters(): 139 | num_params += param.numel() 140 | print(net) 141 | print('Total number of parameters: %d' % num_params) 142 | 143 | 144 | def initialize_weights(net): 145 | for m in net.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | m.weight.data.normal_(0, 0.02) 148 | m.bias.data.zero_() 149 | elif isinstance(m, nn.ConvTranspose2d): 150 | m.weight.data.normal_(0, 0.02) 151 | m.bias.data.zero_() 152 | elif isinstance(m, nn.Linear): 153 | m.weight.data.normal_(0, 0.02) 154 | m.bias.data.zero_() 155 | 156 | 157 | def variable(t: torch.Tensor, use_cuda=True, **kwargs): 158 | if torch.cuda.is_available() and use_cuda: 159 | t = t.cuda() 160 | return Variable(t, **kwargs) 161 | 162 | 163 | def load_datasets(args): 164 | print(args.data_file) 165 | 166 | train_file = args.data_file.replace('.pt', '_train.pt') 167 | test_file = args.data_file.replace('.pt', '_test.pt') 168 | 169 | data_train = torch.load(os.path.join(args.data_dir, 'Tasks', args.dataset, train_file)) 170 | data_test = torch.load(os.path.join(args.data_dir, 'Tasks', args.dataset, test_file)) 171 | 172 | n_inputs = data_train[0][1].size(1) 173 | n_outputs = 0 174 | for i in range(len(data_train)): 175 | n_outputs = max(n_outputs, data_train[i][2].max()) 176 | n_outputs = max(n_outputs, data_test[i][2].max()) 177 | 178 | return data_train, data_test, n_inputs, n_outputs + 1, len(data_train) 179 | --------------------------------------------------------------------------------