├── .gitignore ├── LICENSE ├── NOTICE.txt ├── README.md ├── approaches ├── DGMa.py ├── DGMw.py ├── DGMw_imnet.py └── __init__.py ├── cfg ├── __init__.py ├── cfg_imnet_dgmw.yml ├── cfg_mnist_dgma.yml ├── cfg_mnist_dgmw.yml ├── cfg_svhn_dgmw.yml └── load_config.py ├── dat └── .gitignore ├── dataloaders ├── __init__.py ├── cifar_10.py ├── split_MNIST.py └── split_SVHN.py ├── lib ├── data_converter.py ├── data_io.py └── data_manager.py ├── logs └── .gitignore ├── networks ├── __init__.py ├── net_DGMa.py ├── net_DGMw.py ├── net_DGMw_imnet.py └── resnet.py ├── outputs └── .gitignore ├── requierements.txt ├── run.py ├── run_DGMw_imagenet.py └── utils ├── __init__.py ├── evaluation.py ├── folder.py ├── inception.py ├── inception_score.py ├── logger.py ├── spectral_normalization.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | venv/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SAP SAMPLE CODE LICENSE AGREEMENT 2 | 3 | Please scroll down and read the following SAP Sample Code License Agreement carefully ("Agreement"). By downloading, installing, or otherwise using the SAP sample code or any materials that accompany the sample code documentation (collectively, the "Sample Code"), You agree that this Agreement forms a legally binding agreement between You ("You" or "Your") and SAP SE, for and on behalf of itself and its subsidiaries and affiliates (as defined in Section 15 of the German Stock Corporation Act), and You agree to be bound by all of the terms and conditions stated in this Agreement. If You are trying to access or download the Sample Code on behalf of Your employer or as a consultant or agent of a third party (either "Your Company"), You represent and warrant that You have the authority to act on behalf of and bind Your Company to the terms of this Agreement and everywhere in this Agreement that refers to 'You' or 'Your' shall also include Your Company. If You do not agree to these terms, do not attempt to access or use the Sample Code. 4 | 5 | 1. LICENSE: Subject to the terms of this Agreement, SAP grants You a nonexclusive, non-transferable, non-sublicensable, revocable, royalty-free, limited license to use, copy, and modify the Sample Code solely for Your internal business purposes. 6 | 7 | 2. RESTRICTIONS: You must not use the Sample Code to: (a) impair, degrade or reduce the performance or security of any SAP products, services or related technology (collectively, "SAP Products"); (b) enable the bypassing or circumventing of SAP's license restrictions and/or provide users with access to the SAP Products to which such users are not licensed; or (c) permit mass data extraction from an SAP Product to a non-SAP Product, including use, modification, saving or other processing of such data in the non-SAP Product. 
Further, You must not: (i) provide or make the Sample Code available to any third party other than your authorized employees, contractors and agents (collectively, “Representatives”) and solely to be used by Your Representatives for Your own internal business purposes; ii) remove or modify any marks or proprietary notices from the Sample Code; iii) assign this Agreement, or any interest therein, to any third party; (iv) use any SAP name, trademark or logo without the prior written authorization of SAP; or (v) use the Sample Code to modify an SAP Product or decompile, disassemble or reverse engineer an SAP Product (except to the extent permitted by applicable law). You are responsible for any breach of the terms of this Agreement by You or Your Representatives. 3. INTELLECTUAL PROPERTY: SAP or its licensors retain all ownership and intellectual property rights in and to the Sample Code and SAP Products. In exchange for the right to use, copy and modify the Sample Code provided under this Agreement, You covenant not to assert any intellectual property rights in or to any of Your products, services, or related technology that are based on or incorporate the Sample Code against any individual or entity in respect of any current or future SAP Products. 8 | 9 | 4. SAP AND THIRD PARTY APIS: The Sample Code may include API (application programming interface) calls to SAP and third-party products or services. The access or use of the third-party products and services to which the API calls are directed may be subject to additional terms and conditions between you and SAP or such third parties. You (and not SAP) are solely responsible for understanding and complying with any additional terms and conditions that apply to the access or use of those APIs and/or third-party products and services. SAP does not grant You any rights in or to these APIs, products or services under this Agreement. 10 | 11 | 5. FREE AND OPEN SOURCE COMPONENTS: The Sample Code may include third party free or open source components ("FOSS Components"). You may have additional rights in such FOSS Components that are provided by the third party licensors of those components. 12 | 13 | 6. THIRD PARTY DEPENDENCIES: The Sample Code may require third party software dependencies ("Dependencies") for the use or operation of the Sample Code. These Dependencies may be identified by SAP in Maven POM files, documentation or by other means. SAP does not grant You any rights in or to such Dependencies under this Agreement. You are solely responsible for the acquisition, installation and use of such Dependencies. 14 | 15 | 7. WARRANTY: 16 | a) If You are located outside the US or Canada: AS THE SAMPLE CODE IS PROVIDED TO YOU FREE OF CHARGE, SAP DOES NOT GUARANTEE OR WARRANT ANY FEATURES OR QUALITIES OF THE SAMPLE CODE OR GIVE ANY UNDERTAKING WITH REGARD TO ANY OTHER QUALITY. NO SUCH WARRANTY OR UNDERTAKING SHALL BE IMPLIED BY YOU FROM ANY DESCRIPTION IN THE SAMPLE CODE OR ANY OTHER MATERIALS, COMMUNICATION OR ADVERTISEMENT. IN PARTICULAR, SAP DOES NOT WARRANT THAT THE SAMPLE CODE WILL BE AVAILABLE UNINTERRUPTED, ERROR FREE, OR PERMANENTLY AVAILABLE. ALL WARRANTY CLAIMS RESPECTING THE SAMPLE CODE ARE SUBJECT TO THE LIMITATION OF LIABILITY STIPULATED IN SECTION 8 BELOW. 17 | b) If You are located in the US or Canada: THE SAMPLE CODE IS LICENSED TO YOU "AS IS", WITHOUT ANY WARRANTY, ESCROW, TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF SAP. 
SAP MAKES NO EXPRESS OR IMPLIED WARRANTIES OR CONDITIONS OF SALE OF ANY TYPE WHATSOEVER, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY AND OF FITNESS FOR A PARTICULAR PURPOSE. IN PARTICULAR, SAP DOES NOT WARRANT THAT THE SAMPLE CODE WILL BE AVAILABLE UNINTERRUPTED, ERROR FREE, OR PERMANENTLY AVAILABLE. YOU ASSUME ALL RISKS ASSOCIATED WITH THE USE OF THE SAMPLE CODE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, AVAILABILITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 18 | c) For all locations: SAP DOES NOT MAKE ANY REPRESENTATIONS OR WARRANTIES IN RESPECT OF THIRD PARTY DEPENDENCIES, APIS, PRODUCTS AND SERVICES, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY AND OF FITNESS FOR A PARTICULAR PURPOSE. IN PARTICULAR, SAP DOES NOT WARRANT THAT THIRD PARTY DEPENDENCIES, APIS, PRODUCTS AND SERVICES WILL BE AVAILABLE, ERROR FREE, INTEROPERABLE WITH THE SAMPLE CODE, SUITABLE FOR ANY PARTICULAR PURPOSE OR NON-INFRINGING. YOU ASSUME ALL RISKS ASSOCIATED WITH THE USE OF THIRD PARTY DEPENDENCIES, APIS, PRODUCTS AND SERVICES, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, AVAILABILITY, PERFORMANCE, DATA LOSS, UTILITY IN A PRODUCTION ENVIRONMENT, AND NON-INFRINGEMENT. IN NO EVENT WILL SAP BE LIABLE DIRECTLY OR INDIRECTLY IN RESPECT OF ANY USE OF THIRD PARTY DEPENDENCIES, APIS, PRODUCTS AND SERVICES BY YOU. 19 | 20 | 8. LIMITATION OF LIABILITY: 21 | a) If You are located outside the US or Canada: IRRESPECTIVE OF THE LEGAL REASONS, SAP SHALL ONLY BE LIABLE FOR DAMAGES UNDER THIS AGREEMENT IF SUCH DAMAGE (I) CAN BE CLAIMED UNDER THE GERMAN PRODUCT LIABILITY ACT OR (II) IS CAUSED BY INTENTIONAL MISCONDUCT OF SAP OR (III) CONSISTS OF PERSONAL INJURY. IN ALL OTHER CASES, NEITHER SAP NOR ITS EMPLOYEES, AGENTS AND SUBCONTRACTORS SHALL BE LIABLE FOR ANY KIND OF DAMAGE OR CLAIMS HEREUNDER. 22 | b) If You are located in the US or Canada: IN NO EVENT SHALL SAP BE LIABLE TO YOU, YOUR COMPANY OR TO ANY THIRD PARTY FOR ANY DAMAGES IN AN AMOUNT IN EXCESS OF $100 ARISING IN CONNECTION WITH YOUR USE OF OR INABILITY TO USE THE SAMPLE CODE OR IN CONNECTION WITH SAP'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SAMPLE CODE, OR AS A RESULT OF ANY DEFECT IN THE SAMPLE CODE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST SAP, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. YOUR SOLE REMEDY IN THE EVENT OF BREACH OF THIS AGREEMENT BY SAP OR FOR ANY OTHER CLAIM RELATED TO THE SAMPLE CODE SHALL BE TERMINATION OF THIS AGREEMENT. NOTWITHSTANDING ANYTHING TO THE CONTRARY HEREIN, UNDER NO CIRCUMSTANCES SHALL SAP OR ITS LICENSORS BE LIABLE TO YOU OR ANY OTHER PERSON OR ENTITY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, OR INDIRECT DAMAGES, LOSS OF GOOD WILL OR BUSINESS PROFITS, WORK STOPPAGE, DATA LOSS, COMPUTER FAILURE OR MALFUNCTION, ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSS, OR EXEMPLARY OR PUNITIVE DAMAGES. 23 | 24 | 9. INDEMNITY: You will fully indemnify, hold harmless and defend SAP against lawsuits based on any claim: (a) that any of Your products, services or related technology that are based on or incorporate the Sample Code infringes or misappropriates any patent, copyright, trademark, trade secrets, or other proprietary rights of a third party, or (b) related to Your alleged violation of the terms of this Agreement. 25 | 26 | 10. EXPORT: The Sample Code is subject to German, EU and US export control regulations. 
You confirm that: a) You will not use the Sample Code for, and will not allow the Sample Code to be used for, any purposes prohibited by German, EU and US law, including, without limitation, for the development, design, manufacture or production of nuclear, chemical or biological weapons of mass destruction; b) You are not located in Cuba, Iran, Sudan, Iraq, North Korea, Syria, nor any other country to which the United States has prohibited export or that has been designated by the U.S. Government as a "terrorist supporting" country (any, an "US Embargoed Country"); c) You are not a citizen, national or resident of, and are not under the control of, a US Embargoed Country; d) You will not download or otherwise export or re-export the Sample Code, directly or indirectly, to a US Embargoed Country nor to citizens, nationals or residents of a US Embargoed Country; e) You are not listed on the United States Department of Treasury lists of Specially Designated Nationals, Specially Designated Terrorists, and Specially Designated Narcotic Traffickers, nor listed on the United States Department of Commerce Table of Denial Orders or any other U.S. government list of prohibited or restricted parties and f) You will not download or otherwise export or re-export the Sample Code, directly or indirectly, to persons on the above-mentioned lists. 27 | 28 | 11. SUPPORT: SAP does not offer support for the Sample Code. 29 | 30 | 12. TERM AND TERMINATION: You may terminate this Agreement by destroying all copies of the Sample Code in Your possession or control. SAP may terminate Your license to use the Sample Code immediately if You fail to comply with any of the terms of this Agreement, or, for SAP's convenience by providing you with ten (10) days written notice of termination. In case of termination or expiration of this Agreement, You must immediately destroy all copies of the Sample Code in your possession or control. In the event Your Company is acquired (by merger, purchase of stock, assets or intellectual property or exclusive license), or You become employed, by a direct competitor of SAP, then this Agreement and all licenses granted to You in this Agreement shall immediately terminate upon the date of such acquisition or change of employment. 31 | 32 | 13. LAW/VENUE: 33 | a) If You are located outside the US or Canada: This Agreement is governed by and construed in accordance with the laws of Germany without reference to its conflicts of law principles. You and SAP agree to submit to the exclusive jurisdiction of, and venue in, the courts located in Karlsruhe, Germany in any dispute arising out of or relating to this Agreement or the Sample Code. The United Nations Convention on Contracts for the International Sale of Goods shall not apply to this Agreement. 34 | b) If You are located in the US or Canada: This Agreement shall be governed by and construed in accordance with the laws of the State of New York, USA without reference to its conflicts of law principles. You and SAP agree to submit to the exclusive jurisdiction of, and venue in, the courts located in New York, New York, USA in any dispute arising out of or relating to this Agreement or the Sample Code. The United Nations Convention on Contracts for the International Sale of Goods shall not apply to this Agreement. 35 | 36 | 14. MISCELLANEOUS: This Agreement is the complete agreement between the parties respecting the Sample Code. 
This Agreement supersedes all prior or contemporaneous agreements or representations with regard to the Sample Code. If any term of this Agreement is found to be invalid or unenforceable, the surviving provisions shall remain effective. SAP's failure to enforce any right or provision stipulated in this Agreement will not constitute a waiver of such provision, or any other provision of this Agreement. 37 | 38 | 39 | v1.0-071618 -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017-2019 SAP SE or an SAP affiliate company. All rights reserved. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](https://img.shields.io/badge/STATUS-NOT%20CURRENTLY%20MAINTAINED-red.svg?longCache=true&style=flat) 2 | 3 | # Important Notice 4 | This public repository is read-only and no longer maintained. For the latest sample code repositories, visit the [SAP Samples](https://github.com/SAP-samples) organization. 5 | 6 | # Learning to Remember: A Synaptic Plasticity Driven Framework for Continual Learning 7 | ## Description: 8 | A continual learning framework for class-incremental learning, described in the following paper: [arXiv](https://arxiv.org/abs/1904.03137). 9 | Note: this is a work in progress, and the code will be updated dynamically. 10 | 11 | This repository currently contains code to run DGMw experiments on three datasets: MNIST, SVHN, and ImageNet. 12 | ## Requirements 13 | 14 | Please find a list of required packages and versions in the [requierements.txt](https://github.com/SAP/machine-learning-dgm/blob/master/requierements.txt) file. 15 | 16 | 17 | ## How to obtain support 18 | This project is provided "as-is" and any bug reports are not guaranteed to be fixed. 19 | 20 | ## Running the tests 21 | To start experiments, run the script, passing the dataset name as an argument (mnist/svhn): 22 | ``` 23 | python run.py --dataset mnist --method DGMw 24 | ``` 25 | If needed, change the hyperparameters in the corresponding config file in [cfg/](https://github.com/SAP/machine-learning-dgm/tree/master/cfg); a short sketch of how these configs are loaded is given at the end of this README. 26 | 27 | To run on the ImageNet dataset, use the [run_DGMw_imagenet.py](https://github.com/SAP/machine-learning-dgm/tree/master/run_DGMw_imagenet.py) script. 28 | ## License 29 | 30 | This project is licensed under the SAP Sample Code License Agreement except as noted otherwise in the [LICENSE](LICENSE) file. 
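## Loading a config manually
The YAML files under `cfg/` are merged into the defaults defined in `cfg/load_config.py`. A minimal sketch of loading one yourself (assuming you run it from the repository root):
```
# Merge a YAML config into the default options; the YAML values clobber the defaults.
from cfg.load_config import cfg_from_file, opt

cfg_from_file('cfg/cfg_mnist_dgmw.yml')
print(opt.method, opt.dataset, opt.niter, opt.lamb_G)
```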
31 | 32 | -------------------------------------------------------------------------------- /approaches/DGMa.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import copy 3 | import numpy as np 4 | import math 5 | from utils.logger import Logger 6 | import torch 7 | import pickle 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | from itertools import chain 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torch.utils.data 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | import torch.autograd as autograd 17 | from utils.utils import weights_init 18 | 19 | class App(object): 20 | def __init__(self, model, netG, netD, log_dir, outf, niter=100, batchSize=64, imageSize=64, nz=100, nb_label=10, 21 | cuda=True, beta1=0.5, lr_D=0.00002, lr_G=0.0002, lamb_G=1, reinit_D=False, 22 | lambda_adv=1, lambda_wassersten=10, dataset="mnist", device=None, store_model=False): 23 | 24 | 25 | self.store_model = store_model 26 | self.dataset = dataset 27 | self.device = device 29 | 30 | self.lambda_adv = lambda_adv 32 | self.lambda_wassersten = lambda_wassersten 33 | self.model = model 34 | self.netG = netG 35 | self.mask_histo = [] 36 | 37 | self.netD = netD 38 | self.log_dir = log_dir 39 | self.writer = Logger(log_dir) 40 | self.acc_writers = [] 41 | self.reinit_D = reinit_D 42 | self.outf = outf 43 | self.best_valid_acc = 0 44 | self.best_model_index = None 45 | self.best_selected_test_acc = 0 46 | self.niter = niter 47 | self.nb_label = nb_label 48 | self.nz = nz 49 | 53 | 54 | self.lr_D = lr_D 55 | self.lr_G = lr_G 56 | self.beta1 = beta1 57 | self.n_reserver_prev = [0, 0, 0, 0, 0, 0, 0] 58 | 59 | self.lamb_G = lamb_G 60 | self.c_criterion = nn.CrossEntropyLoss() 61 | self.batchSize = batchSize 62 | self.imageSize = imageSize 63 | 64 | input_ = torch.FloatTensor(batchSize, 1, imageSize, imageSize) 65 | noise = torch.FloatTensor(batchSize, nz, 1, 1) 66 | fixed_noise = torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1) 67 | s_label = torch.FloatTensor(batchSize) 68 | c_label = torch.LongTensor(batchSize) 69 | 70 | if cuda: 71 | self.netD.cuda(self.device) 72 | self.netG.cuda(self.device) 73 | self.c_criterion.cuda(self.device) 74 | input_, s_label = input_.cuda(self.device), s_label.cuda(self.device) 75 | c_label = c_label.cuda(self.device) 76 | noise, fixed_noise = noise.cuda(self.device), fixed_noise.cuda(self.device) 77 | 78 | self.input_ = Variable(input_) 79 | self.c_label = Variable(c_label) 80 | self.noise = Variable(noise) 81 | 82 | random_label = np.random.randint(0, nb_label, batchSize) 83 | print('fixed label:{}'.format(random_label)) 84 | 85 | # setup optimizer 86 | self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.lr_D, betas=(self.beta1, 0.999)) 87 | self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr_G, betas=(self.beta1, 0.999)) 88 | 89 | self.mask_pre_G = None 90 | self.mask_back_G = None 91 | self.n_reserver_1_prev = 0 92 | self.n_reserver_2_prev = 0 93 | self.n_reserver_3_prev = 0 94 | self.n_reserver_4_prev = 0 95 | self.global_step = 0 96 | self.unique_classes = [] 97 | self.free_size = self.netG.conv1.weight.shape[1] + self.netG.conv2.weight.shape[1] + \ 98 | self.netG.conv3.weight.shape[1] 99 | self.writer.scalar_summary('Total capacity 
Network (size)', self.free_size, 0) 100 | self.writer.scalar_summary('Total capacity Network (N parametrs)', 101 | np.prod(self.netG.conv1.weight.size()).item() + np.prod( 102 | self.netG.conv2.weight.size()).item() + np.prod( 103 | self.netG.conv3.weight.size()).item(), 0) 104 | 105 | return 106 | 107 | def train(self, data, t, thres_cosh=50, thres_emb=6, clipgrad=10000, smax_g=1e5, use_aux_G=False): 108 | self.best_valid_acc = 0 109 | self.best_selected_test_acc = 0 110 | self.mask_histo.append([[], [], []]) 111 | lr = self.lr_G 112 | self.netG.train() 113 | self.netD.train() 114 | total_size = self.netG.conv1.weight.shape[1] + self.netG.conv2.weight.shape[1] + self.netG.conv3.weight.shape[1] 115 | lamb_G = self.lamb_G * (total_size / self.free_size) 116 | print("lamb_G", lamb_G) 117 | 118 | # init writers for test accuracies 119 | log_dir_task = self.log_dir + "/Acc. Task " + str(t) 120 | self.acc_writers.append(Logger(log_dir_task)) 121 | test_acc_task = [] 122 | data_train_x = data[t]['train']['x'].data.clone() 123 | data_train_y = data[t]['train']['y'].data.clone() 124 | 125 | print('*' * 100) 126 | print("Training on task: ", t) 127 | print('*' * 100) 128 | self.unique_classes.append(torch.unique(data_train_y)) 129 | if t > 0: 130 | old_weights = self.netD.aux_linear.weight.data.clone() 131 | print(old_weights.shape) 132 | self.netD.aux_linear = nn.Linear(self.netD.output_size, old_weights.shape[0] + len( 133 | self.unique_classes[t])).cuda(self.device) 134 | self.netD.aux_linear.apply(weights_init) 135 | self.netD.aux_linear.weight.data[:old_weights.shape[0], :].copy_(copy.copy(old_weights.data.clone())) 136 | self.netG.last.append( 137 | self.model.ConvTranspose2d(self.netG.cap_conv3[t], self.netG.nc, 4, 2, 1, bias=False).cuda(self.device)) 138 | 139 | if t>0: 140 | self.netD.disc_linear.reset_parameters() 141 | self.netD.aux_linear.reset_parameters() 142 | 143 | if self.reinit_D and t>0: 144 | self.netD.apply(weights_init) 145 | self.netD.aux_linear.reset_parameters() 146 | self.netD.disc_linear.reset_parameters() 147 | self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.lr_D, betas=(self.beta1, 0.999)) 148 | self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr_G, betas=(self.beta1, 0.999)) 149 | 150 | if t > 0: 151 | print("Generating datasets") 152 | self.netG.eval() 153 | numb_samples = int(data[t]['train']['y'].shape[0]) 154 | print(numb_samples) 155 | for t_past in range(t): 156 | r = np.arange(numb_samples) 157 | r = torch.LongTensor(r) # .cuda() 158 | data[t_past]["train"]["x"] = None 159 | data[t_past]["train"]["y"] = None 160 | for c in self.unique_classes[t_past]: 161 | print(c) 162 | for i in range(0, len(r), self.batchSize): 163 | if i + self.batchSize <= len(r): 164 | b = r[i:i + self.batchSize] 165 | else: 166 | b = r[i:] 167 | self.c_label.data.resize_(len(b)).fill_(c) 168 | noise, radom_label = self.generate_noise(t_past, len(b), self.c_label.data.cpu().numpy()) 169 | if data[t_past]["train"]["y"] is None: 170 | data[t_past]["train"]["y"] = torch.LongTensor(radom_label).cpu().data.clone() 171 | else: 172 | data[t_past]["train"]["y"] = torch.cat( 173 | (data[t_past]["train"]["y"], torch.LongTensor(radom_label).cpu().data.clone())) 174 | 175 | img_gen, _ = self.netG(noise, t_past, self.c_label, smax_g) 176 | if data[t_past]["train"]["x"] is None: 177 | data[t_past]["train"]["x"] = img_gen.detach().cpu().data.clone() 178 | else: 179 | data[t_past]["train"]["x"] = torch.cat( 180 | (data[t_past]["train"]["x"], img_gen.detach().cpu().data.clone())) 
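# NOTE (generative replay): the loop above rebuilds the training sets of all previous tasks from
# the generator itself; class-conditional noise is sampled for every old class and the resulting
# fake images replace data[t_past]['train'], so no real data from past tasks needs to be stored.
# The permutation below merely shuffles these replayed samples.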
181 | idx = np.random.permutation(data[t_past]["train"]["y"].shape[0]) 182 | data[t_past]["train"]["x"] = data[t_past]["train"]["x"][idx] 183 | data[t_past]["train"]["y"] = data[t_past]["train"]["y"][idx] 184 | print("*" * 100) 185 | print("Generating datasets finished") 186 | self.netG.train() 187 | 188 | print("*" * 100) 189 | try: 190 | for epoch in range(self.niter): 191 | self.netD.train() 192 | s_g_max = (smax_g - 1 / smax_g) * epoch / self.niter + 1 / smax_g 193 | print("s_g_max", s_g_max) 194 | self.write_log_epoch_start(t, epoch, s_g_max, lamb_G) 195 | r = np.arange(data_train_x.shape[0]) 196 | r = torch.LongTensor(r) # .cuda() 197 | 198 | for i in range(0, len(r), self.batchSize): 199 | if i + self.batchSize <= len(r): 200 | b = r[i:i + self.batchSize] 201 | else: 202 | b = r[i:] 203 | 204 | self.netD.zero_grad() 205 | self.netD.train() 206 | 207 | ########################### 208 | # (1) Update D network 209 | ########################### 210 | s_g = s_g_max 211 | # train with real 212 | img, label = data_train_x[b], data_train_y[b] 213 | aux_img = img 214 | aux_label = label 215 | if t > 0: 216 | for t_past in range(t): 217 | aux_img = torch.cat((aux_img, data[t_past]['train']['x'][b].detach())) 218 | aux_label = torch.cat((aux_label, data[t_past]['train']['y'][b])) 219 | 220 | idx = np.random.permutation(aux_img.shape[0]) 221 | aux_img = aux_img[idx] 222 | aux_label = aux_label[idx] 223 | aux_batch_size = aux_img.size(0) 224 | loss_G_aux = [] 225 | for bb in range(0, aux_batch_size, len(b)): 226 | img_b = aux_img[bb:bb + len(b)] 227 | bb_label = aux_label[bb:bb + len(b)] 228 | self.input_.data.resize_(img_b.size()).copy_(img_b.detach()) 229 | self.c_label.data.resize_(img_b.size(0)).copy_(bb_label) # fill with real class labels 230 | _, c_output = self.netD(self.input_) 231 | 232 | c_errD_real = self.c_criterion(c_output, self.c_label) 233 | loss_G_aux.append(c_errD_real) 234 | c_errD_real.backward() 235 | self.optimizerD.step() 236 | self.netD.zero_grad() 237 | 238 | self.netD.zero_grad() 239 | batch_size = img.size(0) 240 | self.input_.resize_(img.size()).copy_(img) 241 | 242 | s_output, _ = self.netD(self.input_) 243 | D_x = s_output.mean() 244 | s_errD_real = -D_x 245 | s_errD_real.backward() 246 | 247 | n_fake = batch_size 248 | self.c_label.data.resize_(batch_size).copy_(label) 249 | noise, radom_label = self.generate_noise(t, n_fake, self.c_label.data.cpu().numpy()) 250 | fake, masks_G = self.netG(noise, t, self.c_label, s_g) 251 | s_output_fake, _ = self.netD(fake.detach()) 252 | D_x_fake = s_output_fake.mean() 253 | errD_fake = D_x_fake # s_errD_fake # + c_errD_fake 254 | errD_fake.backward() 255 | gradient_penalty = self.calc_gradient_penalty(self.netD, self.input_, fake, batch_size) 256 | gradient_penalty.backward() 257 | errD = errD_fake - s_errD_real + gradient_penalty 258 | # torch.nn.utils.clip_grad_norm_(self.netD.parameters(), clipgrad) 259 | self.optimizerD.step() 260 | 261 | ########################### 262 | # (2) Update G network 263 | ########################### 264 | 265 | self.netG.zero_grad() 266 | noise, radom_label = self.generate_noise(t, n_fake, self.c_label.data.cpu().numpy()) 267 | fake, masks_G = self.netG(noise, t, self.c_label, s_g) 268 | s_output, c_output = self.netD(fake) 269 | source_l, mask_reg_l, _ = self.criterion(s_output, masks_G, lamb_G) 270 | c_errG = self.c_criterion(c_output, self.c_label) 271 | 272 | step = (int(math.floor(i / self.batchSize))) + ( 273 | int(math.floor(data_train_x.shape[0] / self.batchSize) + 1) * epoch) 274 | errG 
= -(source_l) + mask_reg_l 275 | if use_aux_G: 276 | errG += c_errG 277 | errG.backward() 278 | 279 | if t > 0: 280 | for n, p in self.netG.named_parameters(): 281 | if n in self.mask_back_G and p.grad is not None: 282 | p.grad.data *= self.mask_back_G[n] 283 | 284 | # Compensate embedding gradients 285 | for n, p in self.netG.named_parameters(): 286 | if "ec" in n: # .startswith('e'): 287 | # print(n) 288 | num = torch.cosh(torch.clamp(s_g * p.data, -thres_cosh, thres_cosh)) + 1 289 | den = torch.cosh(p.data) + 1 290 | if p.grad is not None: 291 | p.grad.data *= s_g_max / s_g * num / den 292 | 293 | # Apply step 294 | torch.nn.utils.clip_grad_norm_(self.netG.parameters(), clipgrad) 295 | self.optimizerG.step() 296 | self.netG.zero_grad() 297 | print('|[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f D(x): %.2f D(G(z)): %.2f / %.2f' 298 | % (epoch, self.niter, i / self.batchSize, data_train_x.shape[0] / self.batchSize, 299 | errD.data.item(), errG.data.item(), s_errD_real, errD_fake, source_l)) 300 | 301 | if epoch % 5 == 0: 302 | self.netG.eval() 303 | with torch.no_grad(): 304 | loss_valid, valid_acc, _ = self.valid(data, t, epoch, self.netD, "valid") 305 | loss, test_accs, _ = self.valid(data, t, epoch, self.netD, "test") 306 | test_acc_task.append(test_accs) 307 | print("-" * 100) 308 | 309 | norm = False 310 | if self.dataset == "svhn": 311 | norm = True 312 | if epoch % 10 == 0: 313 | vutils.save_image(aux_img, 314 | '%s/real_samples_task%d_epoch_%d.png' % (self.outf, t, epoch), 315 | normalize=norm) 316 | lables_noise = torch.FloatTensor(list(chain(*([x] * 40 for x in 317 | range(torch.min(self.unique_classes[t]), 318 | torch.max( 319 | self.unique_classes[t]) + 1))))) 320 | self.c_label.data.resize_(lables_noise.shape[0]).copy_(lables_noise) 321 | noise, radom_label = self.generate_noise(t, lables_noise.shape[0], 322 | self.c_label.data.cpu().numpy()) 323 | fake, _ = self.netG(noise, t, self.c_label, smax_g) 324 | vutils.save_image(fake.data, 325 | '%s/fake_samples__task_%d_epoch_%03d.png' % (self.outf, t, epoch), 326 | normalize=norm) 327 | if t > 0: 328 | for u in range(t + 1): 329 | lables_noise = torch.FloatTensor(list(chain(*([x] * 20 for x in range( 330 | torch.min(self.unique_classes[u]), 331 | torch.max(self.unique_classes[u]) + 1))))) 332 | self.c_label.data.resize_(lables_noise.shape[0]).copy_(lables_noise) 333 | noise, _ = self.generate_noise(u, lables_noise.shape[0], 334 | self.c_label.data.cpu().numpy()) 335 | fake, _ = self.netG(noise, u, self.c_label, smax_g) # s_g_max) 336 | vutils.save_image(fake.data, 337 | '%s/fake_samples_from_%d_task_%d_epoch_%03d.png' % ( 338 | self.outf, u, t, epoch), normalize=norm) 339 | self.write_log_epoch_end(t, epoch, s_g_max) 340 | self.netG.train() 341 | 342 | self.global_step += 1 343 | 344 | loss_valid, valid_acc, _ = self.valid(data, t, epoch, self.netD, "valid") 345 | loss, test_accs, conf_matrixes_task = self.valid(data, t, epoch, self.netD, "test") 346 | test_acc_task.append(test_accs) 347 | except KeyboardInterrupt: 348 | loss_valid, valid_acc, _ = self.valid(data, t, epoch, self.netD, "valid") 349 | loss, test_accs, conf_matrixes_task = self.valid(data, t, epoch, self.netD, "test") 350 | test_acc_task.append(test_accs) 351 | print() 352 | 353 | # Activations mask 354 | task = torch.autograd.Variable(torch.LongTensor([t]).cuda()) 355 | masks_G = self.netG.mask(task, s=smax_g) 356 | 357 | for i_ in range(len(masks_G)): 358 | masks_G[i_][masks_G[i_] >= 0.5] = 1 359 | masks_G[i_][masks_G[i_] < 0.5] = 0 360 | masks_G[i_] = 
torch.autograd.Variable(masks_G[i_].detach().data.clone(), requires_grad=False) 361 | 362 | if t == 0: 363 | self.mask_pre_G = copy.deepcopy(masks_G) 364 | else: 365 | for i_ in range(len(self.mask_pre_G)): 366 | self.mask_pre_G[i_] = torch.max(self.mask_pre_G[i_], masks_G[i_]) 367 | 368 | _, newly_used_cap = self.extand_layers(self.mask_pre_G, t) 369 | 370 | masks_G = self.netG.mask(task, s=smax_g) 371 | 372 | for i_ in range(len(masks_G)): 373 | masks_G[i_] = torch.autograd.Variable(masks_G[i_].data.clone(), requires_grad=False) 374 | self.mask_pre_G[i_] = F.pad(self.mask_pre_G[i_], 375 | [0, masks_G[i_].shape[1] - self.mask_pre_G[i_].shape[1]], "constant", 0) 376 | 377 | self.write_log_task_end(t, masks_G, newly_used_cap, smax_g) 378 | 379 | self.mask_back_G = {} 380 | for n, _ in self.netG.named_parameters(): 381 | vals = self.netG.get_view_for(n, self.mask_pre_G) 382 | if vals is not None: 383 | self.mask_back_G[n] = 1 - vals 384 | self.mask_back_G[n][self.mask_back_G[n] < 0.5] = 0 385 | self.mask_back_G[n][self.mask_back_G[n] >= 0.5] = 1 386 | 387 | #backup model and mask_histo 388 | if self.store_model: 389 | self.store_models(epoch,t) 390 | 391 | 392 | return test_acc_task, conf_matrixes_task, masks_G 393 | 394 | def store_models(self,epoch,t): 395 | self.netG.mask_back = self.mask_back_G 396 | self.netG.mask_pre = self.mask_pre_G 397 | torch.save(self.netG, '%s/netG_task_%d_epoch_%d.pth' % (self.outf + "/models", t, epoch)) 398 | torch.save(self.netD, '%s/netD_task_%d_epoch_%d.pth' % (self.outf + "/models", t, epoch)) 399 | with open(self.outf + '/mask_histo/' + str(t) + '.pickle', 'wb') as handle: 400 | pickle.dump(self.mask_histo, handle, protocol=pickle.HIGHEST_PROTOCOL) 401 | 402 | def generate_noise(self, c, batch_size, label): 403 | self.noise.data.resize_(batch_size, self.nz) 404 | noise_ = np.random.normal(0, 1, (batch_size, self.nz)) 405 | label_onehot = np.zeros((batch_size, self.nb_label)) 406 | label_onehot[np.arange(batch_size), label.astype(int)] = 1. 
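# The first nb_label entries of each noise vector are overwritten with a one-hot encoding of the
# class label (next line), which makes the latent code class-conditional.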
407 | noise_[np.arange(batch_size), :self.nb_label] = label_onehot[np.arange(batch_size)] 408 | 409 | noise_ = (torch.from_numpy(noise_)) 410 | noise_ = noise_.resize_(batch_size, self.nz) # , 1, 1) 411 | self.noise.data.copy_(noise_) 412 | 413 | return copy.copy(self.noise), label 414 | 415 | def write_log_epoch_start(self, t, epoch, smax_g, lamb_G): 416 | task = torch.autograd.Variable(torch.LongTensor([t]).cuda(self.device)) 417 | current_classes = self.unique_classes[t] 418 | self.c_label.data.resize_(current_classes.shape[0]).copy_(current_classes) 419 | 420 | masks_G = self.netG.mask(task, s=smax_g) 421 | total_cap = 0 422 | total_used = 0 423 | cap_string = "Mask capacity G: " 424 | print(len(masks_G)) 425 | print(masks_G[0].shape) 426 | for layer_n in range(len(masks_G)): 427 | cap = torch.sum(masks_G[layer_n]).cpu().data.numpy() / np.prod(masks_G[layer_n].size()).item() 428 | n_total_l = int(masks_G[layer_n].shape[1]) 429 | cap_string += " " + str(cap) 430 | total_cap += n_total_l 431 | total_used += torch.sum(masks_G[layer_n]).cpu().data.numpy() 432 | self.writer.histo_summary("task_%s/L_%s_mask_distribution" % (t, layer_n), 433 | masks_G[layer_n].squeeze(0).cpu().data.numpy(), epoch) 434 | self.mask_histo[t][layer_n].append(masks_G[layer_n].squeeze(0).cpu().data.numpy()) 435 | self.writer.scalar_summary('lamb_G', lamb_G, t) 436 | print(cap_string) 437 | 438 | def calc_gradient_penalty(self, netD, real_data, fake_data, BATCH_SIZE): 439 | LAMBDA = self.lambda_wassersten 440 | DIM = 32 441 | alpha = torch.rand(BATCH_SIZE, 1) 442 | alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement() / BATCH_SIZE)).contiguous() 443 | alpha = alpha.view(BATCH_SIZE, self.netD.nc, DIM, DIM).cuda(self.device) 444 | interpolates = alpha * real_data.cuda(self.device) + ((1 - alpha) * fake_data.cuda(self.device)) 445 | 446 | # if use_cuda: 447 | interpolates = interpolates.cuda(self.device) 448 | interpolates = Variable(interpolates, requires_grad=True) 449 | 450 | disc_interpolates, _ = netD(interpolates) 451 | 452 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 453 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(self.device), create_graph=True, 454 | retain_graph=True, only_inputs=True)[0] 455 | 456 | gradients = gradients.view(gradients.size(0), -1) 457 | 458 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA 459 | return gradient_penalty 460 | 461 | def write_log_epoch_end(self, t, epoch, smax_g): 462 | task = torch.autograd.Variable(torch.LongTensor([t]).cuda(self.device)) 463 | current_classes = self.unique_classes[t] 464 | self.c_label.data.resize_(current_classes.shape[0]).copy_(current_classes) 465 | masks_G = self.netG.mask(task, s=smax_g) 466 | 467 | total_cap = 0 468 | total_used = 0 469 | cap_string = "Mask capacity G: " 470 | for layer_n in range(len(masks_G)): 471 | print(layer_n) 472 | cap = torch.sum(masks_G[layer_n]).cpu().data.numpy() / np.prod(masks_G[layer_n].size()).item() 473 | cap_string += " " + str(cap) 474 | n_total_l = int(masks_G[layer_n].shape[1]) 475 | total_cap += n_total_l 476 | total_used += torch.sum(masks_G[layer_n]).cpu().data.numpy() 477 | # self.writer.scalar_summary('task_%s/L_%s_mask_capacities'%(t, layer_n), cap, epoch) 478 | # self.writer.scalar_summary('Total capacity L_%s'%(layer_n), n_total_l, t) 479 | self.writer.histo_summary("task_%s/L_%s_mask_distribution" % (t, layer_n), 480 | masks_G[layer_n].squeeze(0).cpu().data.numpy(), epoch) 481 | 
self.writer.scalar_summary('task_%s/Total_used_capacity' % (t), total_used / total_cap, epoch) 482 | print(cap_string) 483 | 484 | def write_log_task_end(self, t, masks_G, newly_used_cap, smax_g): 485 | n_free = 0 486 | reused = 0 487 | used_ever = 0 488 | used_last_task = 0 489 | l_reu_sum = 0 490 | for layer_n in range(len(masks_G)): 491 | n_free += torch.sum(self.mask_pre_G[layer_n] == 0) 492 | layer_mask_acc = masks_G[layer_n].data.clone() 493 | for tt in range(t): 494 | task_prev = torch.autograd.Variable(torch.LongTensor([tt]).cuda()) 495 | mask_prev = self.netG.mask(task_prev, s=smax_g) 496 | layer_mask_acc += mask_prev[layer_n] # [layer_mask_acc>0] 497 | 498 | l = layer_mask_acc.data.cpu().numpy() 499 | l_reu = (np.mean(l[l > 0])) 500 | l_reu_sum += l_reu 501 | reused += torch.sum(layer_mask_acc > 1) 502 | used_ever += torch.sum(layer_mask_acc > 0) 503 | used_last_task += torch.sum(masks_G[layer_n] > 0) 504 | # log used capacity new 505 | self.writer.scalar_summary('Newly blocked capacity(% of free)', 506 | (sum(newly_used_cap) / n_free.data.cpu().numpy()) * 100., t) 507 | # log amount of free parameters - should be constant 508 | self.writer.scalar_summary('Free neurons (N)', copy.deepcopy(n_free), t) 509 | 510 | self.writer.scalar_summary('Newly blocked capacity(absolute)', sum(newly_used_cap), t) 511 | self.writer.scalar_summary('Neurons used for task (N)', used_last_task.data.cpu().numpy(), t) 512 | self.writer.scalar_summary('Newly blocked capacity(% of used for task)', 513 | (sum(newly_used_cap) / used_last_task.data.cpu().numpy()) * 100., t) 514 | self.writer.scalar_summary('Reused capacity (% of used for task)', ((used_last_task.data.cpu().numpy() - sum( 515 | newly_used_cap)) / used_last_task.data.cpu().numpy()) * 100., t) 516 | self.writer.scalar_summary('Reused capacity (of used for task)', 517 | ((used_last_task.data.cpu().numpy() - sum(newly_used_cap))), t) 518 | # average number of tasks neurons are reused for 519 | self.writer.scalar_summary('Average reusability (N tasks)', l_reu_sum / len(masks_G), t) 520 | self.writer.scalar_summary('Total capacity Network (size)', 521 | self.netG.conv1.weight.shape[1] + self.netG.conv2.weight.shape[1] + 522 | self.netG.conv3.weight.shape[1], t + 1) 523 | 524 | self.writer.scalar_summary('Total capacity Network (N parametrs)', 525 | np.prod(self.netG.conv1.weight.size()).item() + np.prod( 526 | self.netG.conv2.weight.size()).item() + np.prod( 527 | self.netG.conv3.weight.size()).item(), t + 1) 528 | 529 | def extand_layers(self, masks_G, t): 530 | # adding neurons to keep free capacity constant 531 | extantion = [] 532 | for layer_n in range(len(masks_G)): 533 | n_reserver = int(torch.sum(masks_G[layer_n] == 1).data.cpu().numpy()) - self.n_reserver_prev[layer_n] 534 | self.n_reserver_prev[layer_n] += n_reserver 535 | extantion.append(n_reserver) 536 | current_weight_shapes = self.netG.extand(t, extantion, self.netG.smax) 537 | 538 | return current_weight_shapes, extantion 539 | 540 | def accuracy(self, output, target): 541 | val, max_ = output.max(1) 542 | hits = (max_ == target).float() 543 | acc = torch.sum(hits).data.cpu().numpy() / target.shape[0] 544 | return acc, max_.data.cpu().numpy() 545 | 546 | def valid(self, data, t_max, epoch, net, split="valid"): 547 | # self.netG.eval() 548 | self.netD.eval() 549 | test_accs = [] 550 | confusion = None  # np.zeros((t_max + 1, t_max + 1)) 552 | acc_av = 0 553 | loss = 0 554 | correct_labels = [] 555 | predict_labels = [] 556 | with torch.no_grad(): 557 | for tt in range(t_max + 
1): 558 | total_acc = 0 559 | total_num = 0 560 | r_valid = np.arange(data[tt][split]['x'].shape[0]) 561 | r_valid = torch.LongTensor(r_valid).cuda(self.device) 562 | print("-" * 100) 563 | # true x pred 564 | for ii in range(0, len(r_valid), self.batchSize): 565 | if ii + self.batchSize <= len(r_valid): 566 | b_val = r_valid[ii:ii + self.batchSize] 567 | else: 568 | b_val = r_valid[ii:] 569 | img_valid, label_valid = data[tt][split]['x'][b_val], data[tt][split]['y'][b_val] 570 | self.input_.data.resize_(img_valid.size()).copy_(img_valid) 571 | correct_labels += list(label_valid) 572 | _, c_output_valid = net(self.input_) 573 | loss += self.c_criterion(c_output_valid, label_valid.cuda(self.device)) 574 | c_output_valid = torch.nn.functional.log_softmax(c_output_valid, dim=1) 575 | acc_, pred = self.accuracy(c_output_valid, label_valid.cuda(self.device)) 576 | total_acc += acc_ 577 | total_num += 1 578 | predict_labels += list(pred) 579 | acc = total_acc / total_num 580 | acc_av += acc 581 | test_accs.append(acc) 582 | print('| '+split+' on task:{:d} : acc={:.1f}% |'.format(tt, 100. * acc), end='\n') 583 | self.acc_writers[tt].scalar_summary("Accuracy_"+split, 100. * acc, self.global_step) 584 | acc_av = (100. * acc_av) / (t_max + 1) 585 | self.writer.scalar_summary("Average_Acc. "+split, acc_av, self.global_step) 586 | 587 | if split == "valid": 588 | if self.best_valid_acc < acc_av: 589 | self.best_valid_acc = acc_av 590 | self.best_model_index = self.global_step 591 | self.writer.scalar_summary("Best acc " + split, acc_av, t_max) 592 | elif split == "test": 593 | if self.best_model_index is not None and self.best_model_index == self.global_step: 594 | self.writer.scalar_summary("Best acc " + split, acc_av, t_max) 595 | self.best_selected_test_acc = acc_av 596 | print("*"*100) 597 | print('| Best selected average acc. 
after task {:d} so far: acc={:.1f}% |'.format(t_max, self.best_selected_test_acc), end='\n') 598 | self.netD.train() 599 | return loss, acc_av, confusion 600 | 601 | def criterion(self, y_hat, masks, lamb_G): 602 | reg = 0 603 | count = 0 604 | if self.mask_pre_G is not None: 605 | for m, mp in zip(masks, self.mask_pre_G): 606 | aux = 1 - mp 607 | reg += (m * aux).sum() 608 | count += aux.sum() 609 | else: 610 | for m in masks: 611 | reg += m.sum() 612 | count += np.prod(m.size()).item() 613 | reg /= count 614 | return self.lambda_adv * (y_hat.mean()), lamb_G * reg, reg 615 | 616 | 617 | 618 | -------------------------------------------------------------------------------- /approaches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAP-archive/machine-learning-dgm/78786f0d9469cba201ad0108e4af2387574dc7c0/approaches/__init__.py -------------------------------------------------------------------------------- /cfg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAP-archive/machine-learning-dgm/78786f0d9469cba201ad0108e4af2387574dc7c0/cfg/__init__.py -------------------------------------------------------------------------------- /cfg/cfg_imnet_dgmw.yml: -------------------------------------------------------------------------------- 1 | method: 'DGMw' 2 | dataset: 'imagenet' 3 | log_dir: '/home/ec2-user/imagenet_mnt/logs/logs_imnet/' 4 | dataroot: '/home/ec2-user/imagenet_mnt/train/images/folder/' 5 | dataroot_val: '/home/ec2-user/imagenet_mnt/val/images/folder/' 6 | outf: '/home/ec2-user/imagenet_mnt/final_imgnet/DGMw/outputs_32/imgs' 7 | outf_models: '/home/ec2-user/imagenet_mnt/final_imgnet/DGMw/outputs_32/models' 8 | batchSize: 70 #128 #70 9 | imageSize: 32 10 | workers: 5 11 | manualSeed: 100 12 | nz: 128 # size of the latent z vector 13 | ngf: 64 #64 #50 #50 14 | ndf: 128 15 | niter: 101 #61 #121 16 | lr_D: 0.0002 17 | lr_G: 0.0002 18 | beta1: 0.5 19 | cuda: True 20 | device: 0 21 | device_D: 'cuda:0' 22 | device_G: 'cuda:0' 23 | smax_g: !!float 1e4 24 | lamb_G: 0.001 25 | lambda_adv: 1. 26 | lambda_wasserstein: !!float 0.6 27 | reinit_D : False 28 | store_models: False 29 | aux_G: True 30 | calc_fid_imnet: False 31 | class_idx_imnet: [1,15,29,45,59,65,81,89,90,99] #every new task increases by 100 -------------------------------------------------------------------------------- /cfg/cfg_mnist_dgma.yml: -------------------------------------------------------------------------------- 1 | method: 'DGMa' 2 | dataset: 'mnist' 3 | log_dir: './logs/DGMa/mnist/' 4 | dataroot: 'dat/split_mnist_' 5 | outf: 'outputs/DGMa_mnist' 6 | outf_models: 'outputs/DGMa_mnist/models' 7 | batchSize: 64 8 | imageSize: 32 9 | manualSeed: 0 10 | nz: 128 # size of the latent z vector 11 | ngf: 20 12 | ndf: 32 13 | niter: 101 14 | lr_D: 0.002 15 | lr_G: 0.002 16 | beta1: 0.5 17 | cuda: True 18 | device: 1 19 | smax_g: !!float 1e5 #800 20 | lamb_G: 0.01 21 | lambda_adv: 1. 22 | lambda_wasserstein: 1. 
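# reinit_D: if True, the discriminator is re-initialized before each new task (see approaches/DGMa.py)
# aux_G: if True, the generator is additionally trained with the auxiliary classification loss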
23 | reinit_D : True 24 | store_models: False 25 | aux_G: False -------------------------------------------------------------------------------- /cfg/cfg_mnist_dgmw.yml: -------------------------------------------------------------------------------- 1 | method: 'DGMw' 2 | dataset: 'mnist' 3 | log_dir: './logs/DGMw/mnist/' 4 | dataroot: 'dat/split_mnist_' 5 | outf: 'outputs/DGMw' 6 | outf_models: 'outputs/DGMw/models' 7 | batchSize: 64 8 | imageSize: 32 9 | manualSeed: 0 10 | nz: 128 # size of the latent z vector 11 | ngf: 6 12 | ndf: 32 13 | niter: 101 14 | lr_D: 0.002 15 | lr_G: 0.002 16 | beta1: 0.5 17 | cuda: True 18 | device: 0 19 | smax_g: !!float 1e4 #1e5 20 | lamb_G: 0.08 #1. 21 | lambda_adv: 1. 22 | lambda_wasserstein: 1. 23 | reinit_D : True 24 | store_models: False 25 | aux_G: False #True -------------------------------------------------------------------------------- /cfg/cfg_svhn_dgmw.yml: -------------------------------------------------------------------------------- 1 | method: 'DGMw' 2 | dataset: 'svhn' 3 | log_dir: './logs/DGMw/svhn/' 4 | dataroot: 'dat/SVHN_' 5 | outf: 'outputs/DGMw' 6 | outf_models: 'outputs/DGMw/models' 7 | batchSize: 64 8 | imageSize: 32 9 | manualSeed: 100 10 | nz: 128 # size of the latent z vector 11 | ngf: 6 12 | ndf: 32 13 | niter: 151 14 | lr_D: 0.0005 15 | lr_G: 0.0005 16 | beta1: 0.5 17 | cuda: True 18 | device: 0 19 | smax_g: !!float 1e4 #1e5 20 | lamb_G: 0.001 21 | lambda_adv: 1. 22 | lambda_wasserstein: 1. 23 | reinit_D : True 24 | store_models: False 25 | aux_G: False -------------------------------------------------------------------------------- /cfg/load_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from easydict import EasyDict as edict 3 | __C = edict() 4 | opt = __C 5 | 6 | __C.method = 'DGMw' 7 | __C.dataset = 'mnist' 8 | __C.log_dir = '/home/ec2-user/imagenet_mnt/std_dgm/logs/DGMw/std_runs/' 9 | __C.dataroot= 'dat/split_mnist_' 10 | __C.outf='/home/ec2-user/imagenet_mnt/std_dgm/outputs/DGMw' 11 | __C.outf_models='outputs/DGMw/models' 12 | __C.batchSize= 64 13 | __C.imageSize= 32 14 | __C.manualSeed= 2 15 | __C.nz= 128 # size of the latent z vector 16 | __C.ngf= 6 17 | __C.ndf= 32 18 | __C.niter= 251 19 | __C.lr_D= 0.002 20 | __C.lr_G= 0.002 21 | __C.beta1= 0.5 22 | __C.cuda= True 23 | __C.device= 3 24 | __C.manualSeed= 100 25 | __C.smax_g= 1e+5 26 | __C.lamb_G= 0.08 27 | __C.lambda_adv= 1. 28 | __C.lambda_wasserstein= 1. 29 | __C.reinit_D = True 30 | __C.store_models= False 31 | __C.aux_G=False 32 | 33 | 34 | __C.workers=1 35 | __C.dataroot_val='' 36 | #__C.nruns=10 37 | __C.device_G = 'cuda:0' 38 | __C.device_D = 'cuda:0' 39 | #__C.gpus= '3,4,5,6,7' 40 | #__C.nproc_gpu=2 41 | #__C.sleep=0.9 42 | #__C.tmp_folder='/home/ec2-user/imagenet_mnt/std_dgm/tmp/' 43 | __C.calc_fid_imnet = False 44 | __C.class_idx_imnet = [] 45 | 46 | 47 | def _merge_a_into_b(a, b): 48 | """Merge config dictionary a into config dictionary b, clobbering the 49 | options in b whenever they are also specified in a. 50 | """ 51 | if type(a) is not edict: 52 | return 53 | 54 | for k, v in a.items(): 55 | # a must specify keys that are in b 56 | if not k in b: 57 | raise KeyError('{} is not a valid config key'.format(k)) 58 | 59 | # the types must match, too 60 | old_type = type(b[k]) 61 | if old_type is not type(v): 62 | if isinstance(b[k], np.ndarray): 63 | v = np.array(v, dtype=b[k].dtype) 64 | else: 65 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 'for config key: {}').format(type(b[k]), 67 | type(v), k)) 68 | 69 | # recursively merge dicts 70 | if type(v) is edict: 71 | try: 72 | _merge_a_into_b(a[k], b[k]) 73 | except: 74 | print('Error under config key: {}'.format(k)) 75 | raise 76 | else: 77 | b[k] = v 78 | 79 | 80 | def cfg_from_file(filename): 81 | """Load a config file and merge it into the default options.""" 82 | import yaml 83 | with open(filename, 'r') as f: 84 | yaml_cfg = edict(yaml.load(f, Loader=yaml.SafeLoader)) 85 | 86 | _merge_a_into_b(yaml_cfg, __C) -------------------------------------------------------------------------------- /dat/.gitignore: -------------------------------------------------------------------------------- 1 | [^.]* -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAP-archive/machine-learning-dgm/78786f0d9469cba201ad0108e4af2387574dc7c0/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/cifar_10.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import torch 4 | import utils 5 | from torchvision import datasets,transforms 6 | from sklearn.utils import shuffle 7 | 8 | def get(seed=0,pc_valid=0.10): 9 | data={} 10 | taskcla=[] 11 | size=[3,32,32] 12 | if not os.path.isdir('../dat/binary_cifar_10/'): 13 | os.makedirs('../dat/binary_cifar_10') 14 | 15 | mean=[x/255 for x in [125.3,123.0,113.9]] 16 | std=[x/255 for x in [63.0,62.1,66.7]] 17 | #mean = [0.485, 0.456, 0.406] 18 | #std = [0.229, 0.224, 0.225] 19 | 20 | # CIFAR10 21 | dat={} 22 | transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 23 | #transforms = None 24 | 25 | dat['train']=datasets.CIFAR10('../dat/',train=True,download=True,transform=transform) 26 | dat['test']=datasets.CIFAR10('../dat/',train=False,download=True,transform=transform) 27 | for n in range(10): 28 | data[n]={} 29 | data[n]['name']='cifar10' 30 | data[n]['ncla']=1 31 | data[n]['train']={'x': [],'y': []} 32 | data[n]['test']={'x': [],'y': []} 33 | for s in ['train','test']: 34 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 35 | for image,target in loader: 36 | n=target.numpy()[0] 37 | nn=n 38 | data[nn][s]['x'].append(image) 39 | data[nn][s]['y'].append(n) 40 | 41 | # "Unify" and save 42 | for t in data.keys(): 43 | for s in ['train','test']: 44 | data[t][s]['x']=torch.stack(data[t][s]['x']).view(-1,size[0],size[1],size[2]) 45 | data[t][s]['y']=torch.LongTensor(np.array(data[t][s]['y'],dtype=int)).view(-1) 46 | torch.save(data[t][s]['x'], os.path.join(os.path.expanduser('../dat/binary_cifar_10'),'data'+str(t)+s+'x.bin')) 47 | torch.save(data[t][s]['y'], os.path.join(os.path.expanduser('../dat/binary_cifar_10'),'data'+str(t)+s+'y.bin')) 48 | 49 | # Load binary files 50 | data={} 51 | ids=np.arange(10) 52 | print('Task order =',ids) 53 | for i in range(10): 54 | data[i] = dict.fromkeys(['name','ncla','train','test']) 55 | for s in ['train','test']: 56 | data[i][s]={'x':[],'y':[]} 57 | data[i][s]['x']=torch.load(os.path.join(os.path.expanduser('../dat/binary_cifar_10'),'data'+str(ids[i])+s+'x.bin')) 58 | data[i][s]['y']=torch.load(os.path.join(os.path.expanduser('../dat/binary_cifar_10'),'data'+str(ids[i])+s+'y.bin')) 59 | data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy())) 60 | if 
data[i]['ncla']==1: 61 | data[i]['name']='cifar10-'+str(ids[i]) 62 | else: 63 | data[i]['name']='cifar100-'+str(ids[i]-5) 64 | 65 | # Validation 66 | for t in data.keys(): 67 | r=np.arange(data[t]['train']['x'].size(0)) 68 | r=np.array(shuffle(r,random_state=seed),dtype=int) 69 | nvalid=int(pc_valid*len(r)) 70 | ivalid=torch.LongTensor(r[:nvalid]) 71 | itrain=torch.LongTensor(r[nvalid:]) 72 | data[t]['valid']={} 73 | data[t]['valid']['x']=data[t]['train']['x'][ivalid].clone() 74 | data[t]['valid']['y']=data[t]['train']['y'][ivalid].clone() 75 | data[t]['train']['x']=data[t]['train']['x'][itrain].clone() 76 | data[t]['train']['y']=data[t]['train']['y'][itrain].clone() 77 | 78 | # Others 79 | n=0 80 | for t in data.keys(): 81 | taskcla.append((t,data[t]['ncla'])) 82 | n+=data[t]['ncla'] 83 | data['ncla']=n 84 | 85 | return data,taskcla,size 86 | -------------------------------------------------------------------------------- /dataloaders/split_MNIST.py: -------------------------------------------------------------------------------- 1 | #Copyright 2019 SAP SE 2 | #Licensed under the Apache License, Version 2.0 (the "License"); 3 | #you may not use this file except in compliance with the License. 4 | #You may obtain a copy of the License at 5 | 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | #Unless required by applicable law or agreed to in writing, software 9 | #distributed under the License is distributed on an "AS IS" BASIS, 10 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | #See the License for the specific language governing permissions and 12 | #limitations under the License. 13 | import os,sys 14 | import os.path 15 | import numpy as np 16 | import random 17 | import torch 18 | import torch.utils.data 19 | from torchvision import datasets, transforms 20 | from sklearn.utils import shuffle 21 | import urllib.request 22 | from PIL import Image 23 | import pickle 24 | import utils 25 | 26 | from lib import data_manager 27 | from lib import data_io 28 | 29 | import numpy as np 30 | import keras 31 | from keras.utils import np_utils 32 | 33 | from keras.datasets import mnist 34 | 35 | 36 | def split_dataset_by_labels(X, y, task_labels, nb_classes=None, multihead=False): 37 | """Split dataset by labels. 38 | 39 | Args: 40 | X: data 41 | y: labels 42 | task_labels: list of list of labels, one for each dataset 43 | nb_classes: number of classes (used to convert to one-hot) 44 | Returns: 45 | List of (X, y) tuples representing each dataset 46 | """ 47 | if nb_classes is None: 48 | nb_classes = len(np.unique(y)) 49 | datasets = [] 50 | for labels in task_labels: 51 | idx = np.in1d(y, labels) 52 | if multihead: 53 | label_map = np.arange(nb_classes) 54 | label_map[labels] = np.arange(len(labels)) 55 | data = X[idx], np_utils.to_categorical(label_map[y[idx]], len(labels)) 56 | else: 57 | data = X[idx], np_utils.to_categorical(y[idx], nb_classes) 58 | datasets.append(data) 59 | return datasets 60 | 61 | 62 | def construct_split_mnist(task_labels, split='train', multihead=False): 63 | """Split MNIST dataset by labels. 
64 | 65 | Args: 66 | task_labels: list of list of labels, one for each dataset 67 | split: whether to use train or testing data 68 | 69 | Returns: 70 | List of (X, y) tuples representing each dataset 71 | """ 72 | # Load MNIST data and normalize 73 | nb_classes = 10 74 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 75 | X_train = X_train.reshape(-1, 28, 28) 76 | X_test = X_test.reshape(-1, 28, 28) 77 | X_train = X_train.astype('float32') 78 | X_test = X_test.astype('float32') 79 | X_train /= 255 80 | X_test /= 255 81 | 82 | if split == 'train': 83 | X, y = X_train, y_train 84 | else: 85 | X, y = X_test, y_test 86 | 87 | return split_dataset_by_labels(X, y, task_labels, nb_classes, multihead) 88 | 89 | 90 | 91 | def get(seed=0, data_root=None, fixed_order=False, pc_valid=0.15, n_classes=1, imageSize=None): 92 | print("Getting") 93 | binary = False 94 | if n_classes == 1: 95 | binary = True 96 | ncla=1 97 | task_labels = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]] 98 | 99 | elif n_classes == 2: 100 | ncla=2 101 | task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9] ] 102 | 103 | else: 104 | ncla = 10 105 | task_labels = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] 106 | 107 | data={} 108 | taskcla=[] 109 | size=imageSize # -size of the maximum input, smaller once will be padded with 0 110 | n_tasks = len(task_labels) 111 | training_datasets = construct_split_mnist(task_labels, split='train', multihead=binary) 112 | validation_datasets = construct_split_mnist(task_labels, split='test', multihead = binary) 113 | 114 | if not os.path.isdir(data_root+'/'): 115 | os.makedirs(data_root+'/') 116 | for n, idx in enumerate(range(n_tasks)): 117 | dat = {} 118 | dat['train'] = Split_MNIST_loader( training_datasets[idx], idx, transform=transforms.Compose([transforms.ToPILImage(), transforms.Resize(imageSize),transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]), binary=binary) 119 | dat['test'] = Split_MNIST_loader( validation_datasets[idx], idx, transform=transforms.Compose([transforms.ToPILImage(), transforms.Resize(imageSize),transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]), binary= binary) 120 | data[n] = {} 121 | data[n]['name'] = str(idx) 122 | data[n]['ncla'] = ncla 123 | 124 | for s in ['train', 'test']: 125 | loader = torch.utils.data.DataLoader(dat[s], batch_size=1, shuffle=False) 126 | data[n][s] = {'x': [], 'y': []} 127 | 128 | for image, target in loader: 129 | data[n][s]['x'].append(image[0]) 130 | data[n][s]['y'].append(target.numpy()[0]) 131 | 132 | for s in ['train', 'test']: 133 | # Expand to 5000 134 | data[n][s]['x'] = torch.stack(data[n][s]['x']) 135 | data[n][s]['y'] = torch.LongTensor(np.array(data[n][s]['y'], dtype=int)).view(-1) 136 | #SAVE 137 | torch.save(data[n][s]['x'], os.path.join(os.path.expanduser(data_root), 'data' + str(idx) + s + 'x.bin')) 138 | torch.save(data[n][s]['y'], os.path.join(os.path.expanduser(data_root), 'data' + str(idx) + s + 'y.bin')) 139 | else: 140 | # Load binary files 141 | for n, idx in enumerate(range(n_tasks)): 142 | data[n] = dict.fromkeys(['name', 'ncla', 'train', 'test']) 143 | data[n]['name'] = str(idx) 144 | data[n]['ncla'] = ncla 145 | 146 | # Load 147 | for s in ['train', 'test']: 148 | data[n][s] = {'x': [], 'y': []} 149 | data[n][s]['x'] = torch.load( 150 | os.path.join(os.path.expanduser(data_root), 'data' + str(idx) + s + 'x.bin')) 151 | data[n][s]['y'] = torch.load( 152 | os.path.join(os.path.expanduser(data_root), 'data' + str(idx) + s + 'y.bin')) 153 | 154 | # Validation 
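# A pc_valid fraction of each task's training split is shuffled out below and kept as data[t]['valid'].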
155 | for t in data.keys(): 156 | r = np.arange(data[t]['train']['x'].size(0)) 157 | print("data[t]['train']['x'].size(0)", data[t]['train']['x'].shape) 158 | print(r) 159 | r = np.array(shuffle(r, random_state=seed), dtype=int) 160 | nvalid = int(pc_valid * len(r)) 161 | print(nvalid) 162 | 163 | ivalid = torch.LongTensor(r[:nvalid]) 164 | itrain = torch.LongTensor(r[nvalid:]) 165 | 166 | data[t]['valid'] = {} 167 | data[t]['valid']['x'] = data[t]['train']['x'][ivalid].clone() 168 | data[t]['valid']['y'] = data[t]['train']['y'][ivalid].clone() 169 | data[t]['train']['x'] = data[t]['train']['x'][itrain].clone() 170 | data[t]['train']['y'] = data[t]['train']['y'][itrain].clone() 171 | 172 | # Others 173 | n = 0 174 | for t in data.keys(): 175 | taskcla.append((t, data[t]['ncla'])) 176 | n += data[t]['ncla'] 177 | #data['ncla'] = n 178 | 179 | return data, taskcla, size 180 | 181 | class Split_MNIST_loader(torch.utils.data.Dataset): 182 | 183 | def __init__(self, data, task, train=True, transform=None, max_samples=float('Inf'), seed=0, binary=True): 184 | self.transform = transform 185 | random.seed(seed) 186 | self.data = data[0] 187 | self.labels = data[1] 188 | if not binary: 189 | if len(self.labels.shape) > 1: 190 | self.labels = np.where(self.labels == 1)[1] #- 2*task 191 | 192 | def __getitem__(self, index): 193 | """ 194 | Args: index (int): Index 195 | Returns: tuple: (image, target) where target is index of the target class. 196 | """ 197 | img, target = self.data[index], self.labels[index] 198 | # doing this so that it is consistent with all other datasets 199 | # to return a PIL Image 200 | img = img.reshape(1, 28, 28) 201 | img = torch.Tensor(img) 202 | if self.transform is not None: 203 | img = self.transform(img) 204 | return img, target 205 | 206 | def __len__(self): 207 | return len(self.data) -------------------------------------------------------------------------------- /dataloaders/split_SVHN.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SAP SE 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
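# [Editor's note] A minimal, self-contained sketch (not part of the original
# repository) of how the task dictionaries built by these dataloaders are
# typically consumed. split_MNIST.get returns (data, taskcla, size); note that
# split_SVHN.get below returns (data, None, None) instead.
def _demo_iterate_tasks(data, taskcla):
    # data: dict mapping task id -> {'name', 'ncla', 'train'/'valid'/'test': {'x', 'y'}}
    # taskcla: list of (task_id, n_classes) tuples
    for t, ncla in taskcla:
        xtrain = data[t]['train']['x']   # stacked image tensor for task t
        ytrain = data[t]['train']['y']   # LongTensor of labels for task t
        print('task', data[t]['name'], '| classes:', ncla, '| train x:', tuple(xtrain.shape))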
13 | import os 14 | import numpy as np 15 | import torch 16 | import copy 17 | import torchvision.datasets as dset 18 | import torchvision.transforms as transforms 19 | 20 | 21 | def load_data(dataset_train, dataset_test, classes_test, classes_train): 22 | dataset_train_ = copy.deepcopy(dataset_train) 23 | dataset_test_ = copy.deepcopy(dataset_test) 24 | idx = np.nonzero(np.isin(dataset_train.labels, classes_train))[0] 25 | dataset_train_.labels = np.array(dataset_train.labels).squeeze()[idx] 26 | dataset_train_.data = dataset_train.data[idx] 27 | 28 | idx_test = np.nonzero(np.isin(dataset_test.labels, classes_test))[0] 29 | dataset_test_.labels = np.array(dataset_test.labels).squeeze()[idx_test] 30 | dataset_test_.data = dataset_test.data[idx_test] 31 | 32 | return dataset_train_, dataset_test_ 33 | 34 | 35 | def get(seed=0, data_root=None, fixed_order=False, pc_valid=0.1, n_classes=1, imageSize=None): 36 | ncla = n_classes 37 | size = imageSize 38 | data = {} 39 | if not os.path.isdir(data_root + '/split_svhn_' + str(size) + '/'): 40 | os.makedirs(data_root + '/split_svhn_' + str(size) + '/') 41 | dataset_train = dset.SVHN(root=data_root, split='train',download=True, 42 | transform=transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.5, 0.5, 0.5), 45 | (0.5, 0.5, 0.5)), 46 | ])) 47 | dataset_test = dset.SVHN(root=data_root, split='test',download=True, 48 | transform=transforms.Compose([ 49 | #transforms.Resize(imageSize), 50 | #transforms.CenterCrop(imageSize), 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.5, 0.5, 0.5), 53 | (0.5, 0.5, 0.5)), 54 | ])) 55 | for n in range(10): 56 | print("Loading data task: ", n) 57 | data[n] = {} 58 | data[n]['name'] = str(n) 59 | data[n]['ncla'] = ncla 60 | dataset_train_, dataset_test_ = load_data(dataset_train,dataset_test, [n], [n]) 61 | train_length = int((1 - pc_valid) * len(dataset_train_.labels)) 62 | valid_length = len(dataset_train_) - train_length 63 | dataset_train__, valid_subset = torch.utils.data.random_split(dataset_train_, (train_length, valid_length)) 64 | loader_train = torch.utils.data.DataLoader(dataset_train__, batch_size=len(dataset_train__), shuffle=False) 65 | loader_test = torch.utils.data.DataLoader(dataset_test_, batch_size=len(dataset_test_), shuffle=False) 66 | loader_valid = torch.utils.data.DataLoader(valid_subset, batch_size=len(valid_subset), shuffle=False) 67 | x_train, y_train = next(iter(loader_train)) 68 | x_test, y_test = next(iter(loader_test)) 69 | x_valid, y_valid = next(iter(loader_valid)) 70 | data[n]['train'] = {'x': x_train, 'y': y_train} 71 | data[n]['valid'] = {'x': x_valid, 'y': y_valid} 72 | data[n]['test'] = {'x': x_test, 'y': y_test} 73 | 74 | for s in ['train', 'test', 'valid']: 75 | #data[n][s]['x'] = torch.stack(data[n][s]['x']) 76 | data[n][s]['y'] = torch.LongTensor(np.array(data[n][s]['y'], dtype=int)).view(-1) 77 | print(data[n][s]['y']) 78 | # SAVE 79 | torch.save(data[n][s]['x'], os.path.join(os.path.expanduser(data_root + '/split_svhn_' + str(size)), 80 | 'data' + str(n) + s + 'x.bin')) 81 | torch.save(data[n][s]['y'], os.path.join(os.path.expanduser(data_root + '/split_svhn_' + str(size)), 82 | 'data' + str(n) + s + 'y.bin')) 83 | 84 | else: 85 | # Load binary files 86 | for idx in range(10): 87 | data[idx] = dict.fromkeys(['name', 'ncla', 'train', 'test']) 88 | data[idx]['name'] = str(idx) 89 | data[idx]['ncla'] = ncla 90 | 91 | # Load 92 | for s in ['train', 'test', 'valid']: 93 | data[idx][s] = {'x': [], 'y': []} 94 | data[idx][s]['x'] = torch.load( 95 | 
os.path.join(os.path.expanduser(data_root + '/split_svhn_' + str(size)),
96 | 'data' + str(idx) + s + 'x.bin'))
97 | # data[idx][s]['x'] = torch.LongTensor(data[idx][s]['x'])
98 | data[idx][s]['y'] = torch.load(
99 | os.path.join(os.path.expanduser(data_root + '/split_svhn_' + str(size)),
100 | 'data' + str(idx) + s + 'y.bin'))
101 | # data[idx][s]['y'] = torch.LongTensor(data[idx][s]['y'])
102 | # NOTE: unlike split_MNIST.get, this loader returns no taskcla/size; callers must handle the Nones
103 | return data, None, None
-------------------------------------------------------------------------------- /lib/data_converter.py: --------------------------------------------------------------------------------
1 | # Functions performing various data conversions for the ChaLearn AutoML challenge
2 | 
3 | # Main contributors: Arthur Pesah and Isabelle Guyon, August-October 2014
4 | 
5 | # ALL INFORMATION, SOFTWARE, DOCUMENTATION, AND DATA ARE PROVIDED "AS-IS".
6 | # ISABELLE GUYON, CHALEARN, AND/OR OTHER ORGANIZERS OR CODE AUTHORS DISCLAIM
7 | # ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
8 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE, AND THE
9 | # WARRANTY OF NON-INFRINGEMENT OF ANY THIRD PARTY'S INTELLECTUAL PROPERTY RIGHTS.
10 | # IN NO EVENT SHALL ISABELLE GUYON AND/OR OTHER ORGANIZERS BE LIABLE FOR ANY SPECIAL,
11 | # INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER ARISING OUT OF OR IN
12 | # CONNECTION WITH THE USE OR PERFORMANCE OF SOFTWARE, DOCUMENTS, MATERIALS,
13 | # PUBLICATIONS, OR INFORMATION MADE AVAILABLE FOR THE CHALLENGE.
14 | 
15 | import numpy as np
16 | from scipy.sparse import *
17 | 
18 | # Note: to check for nan values use np.any(np.isnan(X_train))
19 | 
20 | def file_to_array (filename, verbose=False):
21 | ''' Converts a file to a list of lists of STRINGs.
22 | It differs from np.genfromtxt in that the number of columns doesn't need to be constant'''
23 | data =[]
24 | with open(filename, "r") as data_file:
25 | if verbose: print ("Reading {}...".format(filename))
26 | lines = data_file.readlines()
27 | if verbose: print ("Converting {} to correct array...".format(filename))
28 | data = [lines[i].strip().split() for i in range (len(lines))]
29 | return data
30 | 
31 | def read_first_line (filename):
32 | ''' Read the first line of a file'''
33 | data =[]
34 | with open(filename, "r") as data_file:
35 | line = data_file.readline()
36 | data = line.strip().split()
37 | return data
38 | 
39 | def num_lines (filename):
40 | ''' Count the number of lines of a file'''
41 | return sum(1 for line in open(filename))
42 | 
43 | def binarization (array):
44 | ''' Takes a binary-class datafile and turns the max value (positive class) into 1 and the min value into 0'''
45 | array = np.array(array, dtype=float) # conversion needed to use np.inf below
46 | if len(np.unique(array)) > 2:
47 | raise ValueError ("The argument must be a binary-class datafile. {} classes detected".format(len(np.unique(array))))
48 | 
49 | # This manipulation avoids errors in data with, for example, classes '1' and '2'.
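# [Editor's note] Worked example of the max/inf trick below: for labels
# [1, 2, 2, 1], the max (2) is first mapped to np.inf, the min (1) to 0, and
# then inf to 1, yielding [0, 1, 1, 0] without the two assignments ever
# clobbering each other.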
array[array == np.amax(array)] = np.inf
51 | array[array == np.amin(array)] = 0
52 | array[array == np.inf] = 1
53 | return np.array(array, dtype=int)
54 | 
55 | def sparse_file_to_sparse_list (filename, verbose=True):
56 | ''' Converts a sparse data file to a sparse list, so that :
57 | sparse_list[i][j] = (a,b) means matrix[i][a]=b'''
58 | with open(filename, "r") as data_file:
59 | if verbose: print ("Reading {}...".format(filename))
60 | lines = data_file.readlines()
61 | if verbose: print ("Converting {} to correct array".format(filename))
62 | data = [lines[i].split(' ') for i in range (len(lines))]
63 | if verbose: print ("Converting {} to sparse list".format (filename))
64 | return [[tuple(map(int, data[i][j].rstrip().split(':'))) for j in range(len(data[i])) if data[i][j] != '\n'] for i in range (len(data))]
65 | 
66 | def sparse_list_to_csr_sparse (sparse_list, nbr_features, verbose=True):
67 | ''' This function takes as argument a matrix of tuples representing a sparse matrix and the number of features.
68 | sparse_list[i][j] = (a,b) means matrix[i][a]=b
69 | It converts it into a scipy csr sparse matrix'''
70 | nbr_samples = len(sparse_list)
71 | dok_sparse = dok_matrix ((nbr_samples, nbr_features)) # construction easier w/ dok_sparse...
72 | if verbose: print ("\tConverting sparse list to dok sparse matrix")
73 | for row in range (nbr_samples):
74 | for column in range (len(sparse_list[row])):
75 | (feature,value) = sparse_list[row][column]
76 | dok_sparse[row, feature-1] = value
77 | if verbose: print ("\tConverting dok sparse matrix to csr sparse matrix")
78 | return dok_sparse.tocsr() # ... but csr better for shuffling data or other tricks
79 | 
80 | def multilabel_to_multiclass (array):
81 | array = binarization (array)
82 | return np.array([np.nonzero(array[i,:])[0][0] for i in range (len(array))])
83 | 
84 | def convert_to_num(Ybin, verbose=True):
85 | ''' Convert binary targets to numeric vector (typically classification target values)'''
86 | if verbose: print("\tConverting to numeric vector")
87 | Ybin = np.array(Ybin)
88 | if len(Ybin.shape) ==1:
89 | return Ybin
90 | classid = np.arange(Ybin.shape[1])
91 | Ycont = np.dot(Ybin, classid)
92 | if verbose: print(Ycont)
93 | return Ycont
94 | 
95 | def convert_to_bin(Ycont, nval, verbose=True):
96 | ''' Convert numeric vector to binary (typically classification target values)'''
97 | if verbose: print ("\t_______ Converting to binary representation")
98 | Ybin = [[0] * nval for x in range(len(Ycont))]
99 | for i in range(len(Ybin)):
100 | line = Ybin[i]
101 | line[int(Ycont[i])] = 1
102 | Ybin[i] = line
103 | return Ybin
104 | 
105 | def tp_filter(X, Y, feat_num=1000, verbose=True):
106 | ''' TP feature selection in the spirit of the winners of the KDD cup 2001
107 | Only for binary classification and sparse matrices'''
108 | 
109 | if issparse(X) and len(Y.shape)==1 and len(set(Y))==2 and (sum(Y)/Y.shape[0])<0.1:
110 | if verbose: print("========= Filtering features...")
111 | Posidx=Y>0
112 | #npos = sum(Posidx)
113 | #Negidx=Y<=0
114 | #nneg = sum(Negidx)
115 | 
116 | nz=X.nonzero()
117 | mx=X[nz].max()
118 | if X[nz].min()==mx: # sparse binary
119 | if mx!=1: X[nz]=1
120 | tp=csr_matrix.sum(X[Posidx,:], axis=0)
121 | #fn=npos-tp
122 | #fp=csr_matrix.sum(X[Negidx,:], axis=0)
123 | #tn=nneg-fp
124 | else:
125 | tp=np.sum(X[Posidx,:]>0, axis=0)
126 | #tn=np.sum(X[Negidx,:]<=0, axis=0)
127 | #fn=np.sum(X[Posidx,:]<=0, axis=0)
128 | #fp=np.sum(X[Negidx,:]>0, axis=0)
129 | 
130 | tp=np.ravel(tp)
131 | idx=sorted(range(len(tp)), key=tp.__getitem__,
reverse=True) 132 | return idx[0:feat_num] 133 | else: 134 | feat_num = X.shape[1] 135 | return range(feat_num) 136 | 137 | def replace_missing(X): 138 | # This is ugly, but 139 | try: 140 | if X.getformat()=='csr': 141 | return X 142 | except: 143 | XX = np.nan_to_num(X) 144 | # p=len(X) 145 | # nn=len(X[0])*2 146 | # XX = np.zeros([p,nn]) 147 | # for i in range(len(X)): 148 | # line = X[i] 149 | # line1 = [0 if np.isnan(x) else x for x in line] 150 | # line2 = [1 if np.isnan(x) else 0 for x in line] # indicator of missingness 151 | # XX[i] = line1 + line2 152 | return XX 153 | -------------------------------------------------------------------------------- /lib/data_io.py: -------------------------------------------------------------------------------- 1 | # Functions performing various input/output operations for the ChaLearn AutoML challenge 2 | 3 | # Main contributors: Arthur Pesah and Isabelle Guyon, August-October 2014 4 | 5 | # ALL INFORMATION, SOFTWARE, DOCUMENTATION, AND DATA ARE PROVIDED "AS-IS". 6 | # ISABELLE GUYON, CHALEARN, AND/OR OTHER ORGANIZERS OR CODE AUTHORS DISCLAIM 7 | # ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 8 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE, AND THE 9 | # WARRANTY OF NON-INFRIGEMENT OF ANY THIRD PARTY'S INTELLECTUAL PROPERTY RIGHTS. 10 | # IN NO EVENT SHALL ISABELLE GUYON AND/OR OTHER ORGANIZERS BE LIABLE FOR ANY SPECIAL, 11 | # INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER ARISING OUT OF OR IN 12 | # CONNECTION WITH THE USE OR PERFORMANCE OF SOFTWARE, DOCUMENTS, MATERIALS, 13 | # PUBLICATIONS, OR INFORMATION MADE AVAILABLE FOR THE CHALLENGE. 14 | 15 | from __future__ import print_function 16 | from sys import getsizeof, stderr 17 | from itertools import chain 18 | from collections import deque 19 | try: 20 | from reprlib import repr 21 | except ImportError: 22 | pass 23 | 24 | import numpy as np 25 | import os 26 | import shutil 27 | from scipy.sparse import * # used in data_binary_sparse 28 | from zipfile import ZipFile, ZIP_DEFLATED 29 | from contextlib import closing 30 | from lib import data_converter 31 | from sys import stderr 32 | from sys import version 33 | from glob import glob as ls 34 | from os import getcwd as pwd 35 | #from pip import get_installed_distributions as lib 36 | import yaml 37 | from shutil import copy2 38 | 39 | # ================ Small auxiliary functions ================= 40 | 41 | swrite = stderr.write 42 | 43 | if (os.name == "nt"): 44 | filesep = '\\' 45 | else: 46 | filesep = '/' 47 | 48 | def write_list(lst): 49 | ''' Write a list of items to stderr (for debug purposes)''' 50 | for item in lst: 51 | swrite(item + "\n") 52 | 53 | def print_dict(verbose, dct): 54 | ''' Write a dict to stderr (for debug purposes)''' 55 | if verbose: 56 | for item in dct: 57 | print(item + " = " + str(dct[item])) 58 | 59 | def mkdir(d): 60 | ''' Create a new directory''' 61 | if not os.path.exists(d): 62 | os.makedirs(d) 63 | 64 | def mvdir(source, dest): 65 | ''' Move a directory''' 66 | if os.path.exists(source): 67 | os.rename(source, dest) 68 | 69 | def rmdir(d): 70 | ''' Remove an existingdirectory''' 71 | if os.path.exists(d): 72 | shutil.rmtree(d) 73 | 74 | def vprint(mode, t): 75 | ''' Print to stdout, only if in verbose mode''' 76 | if(mode): 77 | print(t) 78 | 79 | # ================ Output prediction results and prepare code submission ================= 80 | 81 | def write(filename, predictions): 82 | ''' Write prediction scores in prescribed 
format''' 83 | with open(filename, "w") as output_file: 84 | for row in predictions: 85 | if type(row) is not np.ndarray and type(row) is not list: 86 | row = [row] 87 | for val in row: 88 | output_file.write('{0:g} '.format(float(val))) 89 | output_file.write('\n') 90 | 91 | def zipdir(archivename, basedir): 92 | '''Zip directory, from J.F. Sebastian http://stackoverflow.com/''' 93 | assert os.path.isdir(basedir) 94 | with closing(ZipFile(archivename, "w", ZIP_DEFLATED)) as z: 95 | for root, dirs, files in os.walk(basedir): 96 | #NOTE: ignore empty directories 97 | for fn in files: 98 | if fn[-4:]!='.zip': 99 | absfn = os.path.join(root, fn) 100 | zfn = absfn[len(basedir)+len(os.sep):] #XXX: relative path 101 | z.write(absfn, zfn) 102 | 103 | # ================ Inventory input data and create data structure ================= 104 | 105 | def inventory_data(input_dir): 106 | ''' Inventory the datasets in the input directory and return them in alphabetical order''' 107 | # Assume first that there is a hierarchy dataname/dataname_train.data 108 | training_names = inventory_data_dir(input_dir) 109 | ntr=len(training_names) 110 | if ntr==0: 111 | # Try to see if there is a flat directory structure 112 | training_names = inventory_data_nodir(input_dir) 113 | ntr=len(training_names) 114 | if ntr==0: 115 | print('WARNING: Inventory data - No data file found') 116 | training_names = [] 117 | training_names.sort() 118 | return training_names 119 | 120 | def inventory_data_nodir(input_dir): 121 | ''' Inventory data, assuming flat directory structure''' 122 | training_names = ls(os.path.join(input_dir, '*_train.data')) 123 | for i in range(0,len(training_names)): 124 | name = training_names[i] 125 | training_names[i] = name[-name[::-1].index(filesep):-name[::-1].index('_')-1] 126 | check_dataset(input_dir, training_names[i]) 127 | return training_names 128 | 129 | def inventory_data_dir(input_dir): 130 | ''' Inventory data, assuming flat directory structure, assuming a directory hierarchy''' 131 | training_names = ls(input_dir + '/*/*_train.data') # This supports subdirectory structures obtained by concatenating bundles 132 | for i in range(0,len(training_names)): 133 | name = training_names[i] 134 | training_names[i] = name[-name[::-1].index(filesep):-name[::-1].index('_')-1] 135 | check_dataset(os.path.join(input_dir, training_names[i]), training_names[i]) 136 | return training_names 137 | 138 | def check_dataset(dirname, name): 139 | ''' Check the test and valid files are in the directory, as well as the solution''' 140 | valid_file = os.path.join(dirname, name + '_valid.data') 141 | if not os.path.isfile(valid_file): 142 | print('No validation file for ' + name) 143 | exit(1) 144 | test_file = os.path.join(dirname, name + '_test.data') 145 | if not os.path.isfile(test_file): 146 | print('No test file for ' + name) 147 | exit(1) 148 | # Check the training labels are there 149 | training_solution = os.path.join(dirname, name + '_train.solution') 150 | if not os.path.isfile(training_solution): 151 | print('No training labels for ' + name) 152 | exit(1) 153 | return True 154 | 155 | def data(filename, nbr_features=None, verbose = False): 156 | ''' The 2nd parameter makes possible a using of the 3 functions of data reading (data, data_sparse, data_binary_sparse) without changing parameters''' 157 | if verbose: print (np.array(data_converter.file_to_array(filename))) 158 | return np.array(data_converter.file_to_array(filename), dtype=float) 159 | 160 | def data_sparse (filename, nbr_features): 161 | ''' 
This function takes as argument a file representing a sparse matrix 162 | sparse_matrix[i][j] = "a:b" means matrix[i][a] = b 163 | It converts it into a numpy array, using sparse_list_to_array function, and returns this array''' 164 | sparse_list = data_converter.sparse_file_to_sparse_list(filename) 165 | return data_converter.sparse_list_to_csr_sparse (sparse_list, nbr_features) 166 | #return data_converter.sparse_list_to_array (sparse_list, nbr_features) 167 | 168 | def data_binary_sparse (filename, nbr_features): 169 | ''' This function takes as an argument a file representing a binary sparse matrix 170 | binary_sparse_matrix[i][j] = a means matrix[i][j] = 1 171 | It converts it into a numpy array an returns this array. ''' 172 | 173 | data = data_converter.file_to_array (filename) 174 | nbr_samples = len(data) 175 | dok_sparse = dok_matrix ((nbr_samples, nbr_features)) # the construction is easier w/ dok_sparse 176 | print ("Converting {} to dok sparse matrix".format(filename)) 177 | for row in range (nbr_samples): 178 | for feature in data[row]: 179 | dok_sparse[row, int(feature)-1] = 1 180 | print ("Converting {} to csr sparse matrix".format(filename)) 181 | return dok_sparse.tocsr() 182 | 183 | # ================ Copy results from input to output ========================== 184 | 185 | def copy_results(datanames, result_dir, output_dir, verbose): 186 | ''' This function copies all the [dataname.predict] results from result_dir to output_dir''' 187 | missing_files = [] 188 | for basename in datanames: 189 | try: 190 | missing = False 191 | test_files = ls(result_dir + "/" + basename + "*_test*.predict") 192 | if len(test_files)==0: 193 | vprint(verbose, "[-] Missing 'test' result files for " + basename) 194 | missing = True 195 | valid_files = ls(result_dir + "/" + basename + "*_valid*.predict") 196 | if len(valid_files)==0: 197 | vprint(verbose, "[-] Missing 'valid' result files for " + basename) 198 | missing = True 199 | if missing == False: 200 | for f in test_files: copy2(f, output_dir) 201 | for f in valid_files: copy2(f, output_dir) 202 | vprint( verbose, "[+] " + basename.capitalize() + " copied") 203 | else: 204 | missing_files.append(basename) 205 | except: 206 | vprint(verbose, "[-] Missing result files") 207 | return datanames 208 | return missing_files 209 | 210 | # ================ Display directory structure and code version (for debug purposes) ================= 211 | 212 | def show_dir(run_dir): 213 | print('\n=== Listing run dir ===') 214 | write_list(ls(run_dir)) 215 | write_list(ls(run_dir + '/*')) 216 | write_list(ls(run_dir + '/*/*')) 217 | write_list(ls(run_dir + '/*/*/*')) 218 | write_list(ls(run_dir + '/*/*/*/*')) 219 | 220 | def show_io(input_dir, output_dir): 221 | swrite('\n=== DIRECTORIES ===\n\n') 222 | # Show this directory 223 | swrite("-- Current directory " + pwd() + ":\n") 224 | write_list(ls('.')) 225 | write_list(ls('./*')) 226 | write_list(ls('./*/*')) 227 | swrite("\n") 228 | 229 | # List input and output directories 230 | swrite("-- Input directory " + input_dir + ":\n") 231 | write_list(ls(input_dir)) 232 | write_list(ls(input_dir + '/*')) 233 | write_list(ls(input_dir + '/*/*')) 234 | write_list(ls(input_dir + '/*/*/*')) 235 | swrite("\n") 236 | swrite("-- Output directory " + output_dir + ":\n") 237 | write_list(ls(output_dir)) 238 | write_list(ls(output_dir + '/*')) 239 | swrite("\n") 240 | 241 | # write meta data to sdterr 242 | swrite('\n=== METADATA ===\n\n') 243 | swrite("-- Current directory " + pwd() + ":\n") 244 | try: 245 | metadata 
= yaml.safe_load(open('metadata', 'r'))
246 | for key,value in metadata.items():
247 | swrite(key + ': ')
248 | swrite(str(value) + '\n')
249 | except:
250 | swrite("none\n")
251 | swrite("-- Input directory " + input_dir + ":\n")
252 | try:
253 | metadata = yaml.safe_load(open(os.path.join(input_dir, 'metadata'), 'r'))
254 | for key,value in metadata.items():
255 | swrite(key + ': ')
256 | swrite(str(value) + '\n')
257 | swrite("\n")
258 | except:
259 | swrite("none\n")
260 | 
261 | def show_version():
262 | # Python version and library versions
263 | swrite('\n=== VERSIONS ===\n\n')
264 | # Python version
265 | swrite("Python version: " + version + "\n\n")
266 | # Give information on the library versions installed
267 | swrite("Versions of libraries installed:\n")
268 | #map(swrite, sorted(["%s==%s\n" % (i.key, i.version) for i in lib()]))
269 | 
270 | # Compute the total memory size of an object in bytes
271 | 
272 | def total_size(o, handlers={}, verbose=False):
273 | """ Returns the approximate memory footprint of an object and all of its contents.
274 | 
275 | Automatically finds the contents of the following builtin containers and
276 | their subclasses: tuple, list, deque, dict, set and frozenset.
277 | To search other containers, add handlers to iterate over their contents:
278 | 
279 | handlers = {SomeContainerClass: iter,
280 | OtherContainerClass: OtherContainerClass.get_elements}
281 | 
282 | """
283 | dict_handler = lambda d: chain.from_iterable(d.items())
284 | all_handlers = {tuple: iter,
285 | list: iter,
286 | deque: iter,
287 | dict: dict_handler,
288 | set: iter,
289 | frozenset: iter,
290 | }
291 | all_handlers.update(handlers) # user handlers take precedence
292 | seen = set() # track which object ids have already been seen
293 | default_size = getsizeof(0) # estimate sizeof object without __sizeof__
294 | 
295 | def sizeof(o):
296 | if id(o) in seen: # do not double count the same object
297 | return 0
298 | seen.add(id(o))
299 | s = getsizeof(o, default_size)
300 | 
301 | if verbose:
302 | print(s, type(o), repr(o), file=stderr)
303 | 
304 | for typ, handler in all_handlers.items():
305 | if isinstance(o, typ):
306 | s += sum(map(sizeof, handler(o)))
307 | break
308 | return s
309 | 
310 | return sizeof(o)
311 | 
312 | 
-------------------------------------------------------------------------------- /lib/data_manager.py: --------------------------------------------------------------------------------
1 | # Functions performing various input/output operations for the ChaLearn AutoML challenge
2 | 
3 | # Main contributor: Arthur Pesah, August 2014
4 | # Edits: Isabelle Guyon, October 2014
5 | 
6 | # ALL INFORMATION, SOFTWARE, DOCUMENTATION, AND DATA ARE PROVIDED "AS-IS".
7 | # ISABELLE GUYON, CHALEARN, AND/OR OTHER ORGANIZERS OR CODE AUTHORS DISCLAIM
8 | # ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
9 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE, AND THE
10 | # WARRANTY OF NON-INFRINGEMENT OF ANY THIRD PARTY'S INTELLECTUAL PROPERTY RIGHTS.
11 | # IN NO EVENT SHALL ISABELLE GUYON AND/OR OTHER ORGANIZERS BE LIABLE FOR ANY SPECIAL,
12 | # INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER ARISING OUT OF OR IN
13 | # CONNECTION WITH THE USE OR PERFORMANCE OF SOFTWARE, DOCUMENTS, MATERIALS,
14 | # PUBLICATIONS, OR INFORMATION MADE AVAILABLE FOR THE CHALLENGE.
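# [Editor's note] Hypothetical usage sketch for the DataManager class defined
# below; 'mydataset' and './dat' are placeholder names, not shipped with this
# repository.
def _demo_datamanager():
    dm = DataManager(basename='mydataset', input_dir='./dat')  # loads *_train/_valid/_test files
    print(dm)                       # __str__ summarizes self.info and the loaded array shapes
    return dm.data['X_train'], dm.data['Y_train']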
15 | 
16 | from lib import data_converter
17 | from lib import data_io
18 | vprint = data_io.vprint
19 | import numpy as np
20 | try:
21 | import cPickle as pickle
22 | except ImportError:
23 | import pickle
24 | import os
25 | import time
26 | 
27 | class DataManager:
28 | ''' This class loads and saves data easily with a cache and generates a dictionary (self.info) in which each key is a property of the dataset (e.g. name, format, feat_num, ...).
29 | Methods defined here are :
30 | __init__ (...)
31 | x.__init__([(feature, value)]) -> void
32 | Initialize the info dictionary with the tuples (feature, value) given as argument. It recognizes the type of value (int, string) and assigns the value to info[feature]. An unlimited number of tuples can be passed.
33 | 
34 | getInfo (...)
35 | x.getInfo (filename) -> void
36 | Fill the dictionary with an info file. Each line of the info file must have the format 'feature' : value
37 | The information is obtained from the public.info file if it exists, or inferred from the data files
38 | 
39 | getInfoFromFile (...)
40 | x.getInfoFromFile (filename) -> void
41 | Fill the dictionary with an info file. Each line of the info file must have the format 'feature' : value
42 | 
43 | getFormatData (...)
44 | x.getFormatData (filename) -> str
45 | Get the format of the file ('dense', 'sparse' or 'sparse_binary'), either using the 'is_sparse' feature if it exists (for example after a call to getInfoFromFile) and then determining whether it is binary or not, or determining the format on its own.
46 | 
47 | getNbrFeatures (...)
48 | x.getNbrFeatures (*filenames) -> int
49 | Get the number of features, using the data files given. It first checks the format of the data. If it's a matrix, the number of features is trivial. If it's a sparse file, it takes the max feature index found across all files.
50 | 
51 | getTypeProblem (...)
52 | x.getTypeProblem (filename) -> str
53 | Get the kind of problem ('binary.classification', 'multiclass.classification', 'multilabel.classification', 'regression'), using the solution file given.
54 | '''
55 | 
56 | def __init__(self, basename="", input_dir="", verbose=False, replace_missing=True, filter_features=False, max_samples=float('inf')):
57 | '''Constructor'''
58 | self.use_pickle = False # Set this to True to save data as pickle (inefficient)
59 | self.basename = basename
60 | if basename in input_dir:
61 | self.input_dir = input_dir
62 | else:
63 | self.input_dir = os.path.join (input_dir , basename )
64 | if self.use_pickle:
65 | if os.path.exists ("tmp"):
66 | self.tmp_dir = "tmp"
67 | elif os.path.exists ("../tmp"):
68 | self.tmp_dir = "../tmp"
69 | else:
70 | os.makedirs("tmp")
71 | self.tmp_dir = "tmp"
72 | info_file = os.path.join (self.input_dir, basename + '_public.info')
73 | self.info = {}
74 | self.getInfo (info_file)
75 | self.feat_type = self.loadType (os.path.join(self.input_dir, basename + '_feat.type'), verbose=verbose)
76 | self.data = {}
77 | #if True: return
78 | Xtr = self.loadData (os.path.join(self.input_dir, basename + '_train.data'), verbose=verbose, replace_missing=replace_missing)
79 | Ytr = self.loadLabel (os.path.join(self.input_dir, basename + '_train.solution'), verbose=verbose)
80 | max_samples = min(Xtr.shape[0], max_samples)
81 | Xtr = Xtr[0:max_samples]
82 | Ytr = Ytr[0:max_samples]
83 | Xva = self.loadData (os.path.join(self.input_dir, basename + '_valid.data'), verbose=verbose, replace_missing=replace_missing)
84 | Xte = self.loadData (os.path.join(self.input_dir, basename + '_test.data'), verbose=verbose, replace_missing=replace_missing)
85 | # Normally, feature selection should be done as part of a pipeline.
86 | # However, here we do it as a preprocessing for efficiency reasons
87 | idx=[]
88 | if filter_features: # ad hoc feature selection, for the example...
89 | fn = min(Xtr.shape[1], 1000)
90 | idx = data_converter.tp_filter(Xtr, Ytr, feat_num=fn, verbose=verbose)
91 | Xtr = Xtr[:,idx]
92 | Xva = Xva[:,idx]
93 | Xte = Xte[:,idx]
94 | self.feat_idx = np.array(idx).ravel()
95 | self.data['X_train'] = Xtr
96 | self.data['Y_train'] = Ytr
97 | self.data['X_valid'] = Xva
98 | self.data['X_test'] = Xte
99 | 
100 | def __repr__(self):
101 | return "DataManager : " + self.basename
102 | 
103 | def __str__(self):
104 | val = "DataManager : " + self.basename + "\ninfo:\n"
105 | for item in self.info:
106 | val = val + "\t" + item + " = " + str(self.info[item]) + "\n"
107 | val = val + "data:\n"
108 | val = val + "\tX_train = array" + str(self.data['X_train'].shape) + "\n"
109 | val = val + "\tY_train = array" + str(self.data['Y_train'].shape) + "\n"
110 | val = val + "\tX_valid = array" + str(self.data['X_valid'].shape) + "\n"
111 | val = val + "\tX_test = array" + str(self.data['X_test'].shape) + "\n"
112 | val = val + "feat_type:\tarray" + str(self.feat_type.shape) + "\n"
113 | val = val + "feat_idx:\tarray" + str(self.feat_idx.shape) + "\n"
114 | return val
115 | 
116 | def loadData (self, filename, verbose=True, replace_missing=True):
117 | ''' Get the data from a text file in one of 3 formats: matrix, sparse, sparse_binary'''
118 | if verbose: print("========= Reading " + filename)
119 | start = time.time()
120 | if self.use_pickle and os.path.exists (os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle")):
121 | with open (os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"), "rb") as pickle_file:
122 | vprint (verbose, "Loading pickle file : " + os.path.join(self.tmp_dir, os.path.basename(filename) + ".pickle"))
123 | return pickle.load(pickle_file)
124 | if 'format' not in self.info.keys():
125 | 
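# [Editor's note] The three on-disk formats handled here (per data_converter / data_io):
#   dense:         each row lists every feature value, e.g. "0.0 3.5 1.2"
#   sparse:        each row lists 1-based "index:value" pairs, e.g. "2:3.5 3:1.2"
#   sparse_binary: each row lists the 1-based indices of features equal to 1, e.g. "2 3"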
self.getFormatData(filename)
126 | if 'feat_num' not in self.info.keys():
127 | self.getNbrFeatures(filename)
128 | 
129 | data_func = {'dense':data_io.data, 'sparse':data_io.data_sparse, 'sparse_binary':data_io.data_binary_sparse}
130 | 
131 | data = data_func[self.info['format']](filename, self.info['feat_num'])
132 | 
133 | # IMPORTANT: when we replace missing values we double the number of variables
134 | 
135 | if self.info['format']=='dense' and replace_missing and np.any(np.isnan(data)):
136 | vprint (verbose, "Replace missing values by 0 (slow, sorry)")
137 | data = data_converter.replace_missing(data)
138 | if self.use_pickle:
139 | with open (os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"), "wb") as pickle_file:
140 | vprint (verbose, "Saving pickle file : " + os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"))
141 | p = pickle.Pickler(pickle_file)
142 | p.fast = True
143 | p.dump(data)
144 | end = time.time()
145 | if verbose: print( "[+] Success in %5.2f sec" % (end - start))
146 | return data
147 | 
148 | def loadLabel (self, filename, verbose=True):
149 | ''' Get the solution/truth values'''
150 | if verbose: print("========= Reading " + filename)
151 | start = time.time()
152 | if self.use_pickle and os.path.exists (os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle")):
153 | with open (os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"), "rb") as pickle_file:
154 | vprint (verbose, "Loading pickle file : " + os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"))
155 | return pickle.load(pickle_file)
156 | if 'task' not in self.info.keys():
157 | self.getTypeProblem(filename)
158 | 
159 | # IG: Here change to accommodate the new multiclass label format
160 | if self.info['task'] == 'multilabel.classification':
161 | label = data_io.data(filename)
162 | elif self.info['task'] == 'multiclass.classification':
163 | label = data_converter.convert_to_num(data_io.data(filename))
164 | else:
165 | label = np.ravel(data_io.data(filename)) # get a column vector
166 | #label = np.array([np.ravel(data_io.data(filename))]).transpose() # get a column vector
167 | 
168 | if self.use_pickle:
169 | with open (os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"), "wb") as pickle_file:
170 | vprint (verbose, "Saving pickle file : " + os.path.join (self.tmp_dir, os.path.basename(filename) + ".pickle"))
171 | p = pickle.Pickler(pickle_file)
172 | p.fast = True
173 | p.dump(label)
174 | end = time.time()
175 | if verbose: print( "[+] Success in %5.2f sec" % (end - start))
176 | return label
177 | 
178 | def loadType (self, filename, verbose=True):
179 | ''' Get the variable types'''
180 | if verbose: print("========= Reading " + filename)
181 | start = time.time()
182 | type_list = []
183 | if os.path.isfile(filename):
184 | type_list = data_converter.file_to_array (filename, verbose=False)
185 | else:
186 | n=self.info['feat_num']
187 | type_list = [self.info['feat_type']]*n
188 | type_list = np.array(type_list).ravel()
189 | end = time.time()
190 | if verbose: print( "[+] Success in %5.2f sec" % (end - start))
191 | return type_list
192 | 
193 | def getInfo (self, filename, verbose=True):
194 | ''' Get all information {attribute = value} pairs from the filename (public.info file),
195 | if it exists, otherwise, output default values'''
196 | if filename is None:
197 | basename = self.basename
198 | input_dir = self.input_dir
199 | else:
200 | basename = os.path.basename(filename).split('_')[0]
201 | 
input_dir = os.path.dirname(filename) 202 | if os.path.exists(filename): 203 | self.getInfoFromFile (filename) 204 | vprint (verbose, "Info file found : " + os.path.abspath(filename)) 205 | # Finds the data format ('dense', 'sparse', or 'sparse_binary') 206 | self.getFormatData(os.path.join(input_dir, basename + '_train.data')) 207 | else: 208 | vprint (verbose, "Info file NOT found : " + os.path.abspath(filename)) 209 | # Hopefully this never happens because this is done in a very inefficient way 210 | # reading the data multiple times... 211 | self.info['usage'] = 'No Info File' 212 | self.info['name'] = basename 213 | # Get the data format and sparsity 214 | self.getFormatData(os.path.join(input_dir, basename + '_train.data')) 215 | # Assume no categorical variable and no missing value (we'll deal with that later) 216 | self.info['has_categorical'] = 0 217 | self.info['has_missing'] = 0 218 | # Get the target number, label number, target type and task 219 | self.getTypeProblem(os.path.join(input_dir, basename + '_train.solution')) 220 | if self.info['task']=='regression': 221 | self.info['metric'] = 'r2_metric' 222 | else: 223 | self.info['metric'] = 'auc_metric' 224 | # Feature type: Numerical, Categorical, or Binary 225 | # Can also be determined from [filename].type 226 | self.info['feat_type'] = 'Mixed' 227 | # Get the number of features and patterns 228 | self.getNbrFeatures(os.path.join(input_dir, basename + '_train.data'), os.path.join(input_dir, basename + '_test.data'), os.path.join(input_dir, basename + '_valid.data')) 229 | self.getNbrPatterns(basename, input_dir, 'train') 230 | self.getNbrPatterns(basename, input_dir, 'valid') 231 | self.getNbrPatterns(basename, input_dir, 'test') 232 | # Set default time budget 233 | self.info['time_budget'] = 600 234 | return self.info 235 | 236 | def getInfoFromFile (self, filename): 237 | ''' Get all information {attribute = value} pairs from the public.info file''' 238 | with open (filename, "r") as info_file: 239 | lines = info_file.readlines() 240 | features_list = list(map(lambda x: tuple(x.strip("\'").split(" = ")), lines)) 241 | 242 | for (key, value) in features_list: 243 | self.info[key] = value.rstrip().strip("'").strip(' ') 244 | if self.info[key].isdigit(): # if we have a number, we want it to be an integer 245 | self.info[key] = int(self.info[key]) 246 | return self.info 247 | 248 | def getFormatData(self,filename): 249 | ''' Get the data format directly from the data file (in case we do not have an info file)''' 250 | if 'format' in self.info.keys(): 251 | return self.info['format'] 252 | if 'is_sparse' in self.info.keys(): 253 | if self.info['is_sparse'] == 0: 254 | self.info['format'] = 'dense' 255 | else: 256 | data = data_converter.read_first_line (filename) 257 | if ':' in data[0]: 258 | self.info['format'] = 'sparse' 259 | else: 260 | self.info['format'] = 'sparse_binary' 261 | else: 262 | data = data_converter.file_to_array (filename) 263 | if ':' in data[0][0]: 264 | self.info['is_sparse'] = 1 265 | self.info['format'] = 'sparse' 266 | else: 267 | nbr_columns = len(data[0]) 268 | for row in range (len(data)): 269 | if len(data[row]) != nbr_columns: 270 | self.info['format'] = 'sparse_binary' 271 | if 'format' not in self.info.keys(): 272 | self.info['format'] = 'dense' 273 | self.info['is_sparse'] = 0 274 | return self.info['format'] 275 | 276 | def getNbrFeatures (self, *filenames): 277 | ''' Get the number of features directly from the data file (in case we do not have an info file)''' 278 | if 'feat_num' not in 
self.info.keys(): 279 | self.getFormatData(filenames[0]) 280 | if self.info['format'] == 'dense': 281 | data = data_converter.file_to_array(filenames[0]) 282 | self.info['feat_num'] = len(data[0]) 283 | elif self.info['format'] == 'sparse': 284 | self.info['feat_num'] = 0 285 | for filename in filenames: 286 | sparse_list = data_converter.sparse_file_to_sparse_list (filename) 287 | last_column = [sparse_list[i][-1] for i in range(len(sparse_list))] 288 | last_column_feature = [a for (a,b) in last_column] 289 | self.info['feat_num'] = max(self.info['feat_num'], max(last_column_feature)) 290 | elif self.info['format'] == 'sparse_binary': 291 | self.info['feat_num'] = 0 292 | for filename in filenames: 293 | data = data_converter.file_to_array (filename) 294 | last_column = [int(data[i][-1]) for i in range(len(data))] 295 | self.info['feat_num'] = max(self.info['feat_num'], max(last_column)) 296 | return self.info['feat_num'] 297 | 298 | def getNbrPatterns (self, basename, info_dir, datatype): 299 | ''' Get the number of patterns directly from the data file (in case we do not have an info file)''' 300 | line_num = data_converter.num_lines(os.path.join(info_dir, basename + '_' + datatype + '.data')) 301 | self.info[datatype+'_num'] = line_num 302 | return line_num 303 | 304 | def getTypeProblem (self, solution_filename): 305 | ''' Get the type of problem directly from the solution file (in case we do not have an info file)''' 306 | if 'task' not in self.info.keys(): 307 | solution = np.array(data_converter.file_to_array(solution_filename)) 308 | target_num = solution.shape[1] 309 | self.info['target_num']=target_num 310 | if target_num == 1: # if we have only one column 311 | solution = np.ravel(solution) # flatten 312 | nbr_unique_values = len(np.unique(solution)) 313 | if nbr_unique_values < len(solution)/8: 314 | # Classification 315 | self.info['label_num'] = nbr_unique_values 316 | if nbr_unique_values == 2: 317 | self.info['task'] = 'binary.classification' 318 | self.info['target_type'] = 'Binary' 319 | else: 320 | self.info['task'] = 'multiclass.classification' 321 | self.info['target_type'] = 'Categorical' 322 | else: 323 | # Regression 324 | self.info['label_num'] = 0 325 | self.info['task'] = 'regression' 326 | self.info['target_type'] = 'Numerical' 327 | else: 328 | # Multilabel or multiclass 329 | self.info['label_num'] = target_num 330 | self.info['target_type'] = 'Binary' 331 | if any(item > 1 for item in map(np.sum,solution.astype(int))): 332 | self.info['task'] = 'multilabel.classification' 333 | else: 334 | self.info['task'] = 'multiclass.classification' 335 | return self.info['task'] -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | [^.]* -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAP-archive/machine-learning-dgm/78786f0d9469cba201ad0108e4af2387574dc7c0/networks/__init__.py -------------------------------------------------------------------------------- /networks/net_DGMa.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | from utils.utils import weights_init_g, weights_init 6 | 7 | 8 | class 
BatchNorm2d(torch.nn.BatchNorm2d): 9 | def forward(self, input): 10 | self._check_input_dim(input) 11 | if self.training: 12 | momentum = self.momentum 13 | else: 14 | momentum = 0. 15 | return F.batch_norm( 16 | input, self.running_mean, self.running_var, self.weight, self.bias, 17 | self.training or not self.track_running_stats, momentum, self.eps) 18 | 19 | 20 | class ConvTranspose2d(nn.ConvTranspose2d): 21 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 22 | groups=1, bias=True, 23 | num_tasks=1, out_size=None, batch_size=64): 24 | super(ConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, output_padding, 25 | groups, bias, dilation) 26 | 27 | def forward(self, inputx, d_in, d_out, output_size=None): 28 | output_padding = self._output_padding(inputx, output_size, self.stride, self.padding, self.kernel_size) 29 | out = F.conv_transpose2d(inputx, self.weight[:d_in, :d_out, :, :], self.bias, self.stride, self.padding, 30 | output_padding, groups=self.groups, dilation=self.dilation) 31 | return out 32 | 33 | def extand(self, input_channels, out_channels): 34 | w_old = self.weight.data.clone() 35 | 36 | self.out_channels += out_channels 37 | self.in_channels += input_channels 38 | self.weight = Parameter( 39 | torch.Tensor(self.in_channels, self.out_channels // self.groups, *self.kernel_size).cuda()) 40 | if self.bias is not None: 41 | b_old = self.bias.data.clone() 42 | self.bias = Parameter(torch.Tensor(self.in_channels).cuda()) 43 | self.apply(weights_init_g) 44 | self.weight.data[:w_old.shape[0]:, :w_old.shape[1], :, :].copy_(w_old) 45 | if self.bias is not None: 46 | self.bias.data[:b_old.shape[0]].copy_(b_old) 47 | 48 | 49 | class BatchNorm2d_plastic(nn.BatchNorm2d): 50 | def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): 51 | super(BatchNorm2d_plastic, self).__init__(num_features, eps, momentum, affine, track_running_stats) 52 | 53 | def forward(self, inputx, out_dim): 54 | self._check_input_dim(inputx) 55 | out = F.batch_norm(inputx, self.running_mean[:out_dim], self.running_var[:out_dim], self.weight[:out_dim], 56 | self.bias[:out_dim], 57 | self.training or not self.track_running_stats, self.momentum, self.eps) 58 | return out 59 | 60 | 61 | class netG(nn.Module): 62 | def __init__(self, nz, ngf, nc, smax, scalor=1, n_classes=1): 63 | super(netG, self).__init__() 64 | 65 | self.nz = nz 66 | self.gate = torch.nn.Sigmoid() 67 | 68 | self.nc = nc 69 | self.ngf = ngf 70 | self.scalor = scalor 71 | self.smax = smax 72 | self.ReLU = nn.ReLU(True) 73 | self.Tanh = nn.Tanh() 74 | self.conv1 = ConvTranspose2d(nz, self.scalor * ngf * 4, 4, 1, 0, bias=False) 75 | self.cap_conv1 = [self.scalor * ngf * 4] 76 | self.BatchNorms0 = torch.nn.ModuleList() 77 | self.BatchNorms0.append(torch.nn.BatchNorm2d(self.scalor * ngf * 4).apply(weights_init_g)) 78 | 79 | self.conv2 = ConvTranspose2d(self.scalor * ngf * 4, self.scalor * ngf * 2, 4, 2, 1, bias=False) 80 | self.cap_conv2 = [self.scalor * ngf * 2] 81 | self.BatchNorms1 = torch.nn.ModuleList() 82 | self.BatchNorms1.append(torch.nn.BatchNorm2d(self.scalor * ngf * 2).apply(weights_init_g)) 83 | 84 | self.conv3 = ConvTranspose2d(self.scalor * ngf * 2, self.scalor * ngf * 1, 4, 2, 1, bias=False) 85 | self.cap_conv3 = [self.scalor * ngf * 1] 86 | self.BatchNorms2 = torch.nn.ModuleList() 87 | self.BatchNorms2.append(torch.nn.BatchNorm2d(self.scalor * ngf * 1).apply(weights_init_g)) 88 | 89 | 
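# [Editor's note] Design note: the transposed-conv weights above are shared
# across tasks and gated per task by the learned embedding masks (ec1-ec3),
# while each task receives its own BatchNorm instance (the BatchNorms* lists),
# and the cap_conv* lists record how many channels of each layer a given task
# is allowed to use.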
self.apply(weights_init) 90 | 91 | self.last = self.last = torch.nn.ModuleList() 92 | self.ec1 = torch.nn.Embedding(10, ngf * 4 * self.scalor) 93 | self.ec2 = torch.nn.Embedding(10, ngf * 2 * self.scalor) 94 | self.ec3 = torch.nn.Embedding(10, ngf * 1 * self.scalor) 95 | # self.ec4=torch.nn.Embedding(10,ngf * 1*self.scalor) 96 | 97 | self.ec1.weight.data.fill_(0) 98 | self.ec2.weight.data.fill_(0) 99 | self.ec3.weight.data.fill_(0) 100 | 101 | def forward(self, input, t, lables=None, s=1, t_mix=None): 102 | task = torch.autograd.Variable(torch.LongTensor([t]).cuda()) 103 | masks = self.mask(task, s=s) 104 | 105 | gc1, gc2, gc3 = masks 106 | x = self.conv1(input.view(-1, self.nz, 1, 1), self.nz, self.cap_conv1[t]) 107 | x = x * gc1[:, :self.cap_conv1[t]].view(1, -1, 1, 1).expand_as(x) 108 | x = self.BatchNorms0[t](x) # , self.cap_conv1[t]) 109 | x = self.ReLU(x) 110 | 111 | x = self.conv2(x, self.cap_conv1[t], self.cap_conv2[t]) 112 | x = x * gc2[:, :self.cap_conv2[t]].view(1, -1, 1, 1).expand_as(x) 113 | x = self.BatchNorms1[t](x) # , self.cap_conv2[t]) 114 | x = self.ReLU(x) 115 | 116 | x = self.conv3(x, self.cap_conv2[t], self.cap_conv3[t]) 117 | x = x * gc3[:, :self.cap_conv3[t]].view(1, -1, 1, 1).expand_as(x) 118 | x = self.BatchNorms2[t](x) # , self.cap_conv3[t]) 119 | x = self.ReLU(x) 120 | output = self.Tanh(self.last[t](x, self.cap_conv3[t], self.nc)) 121 | 122 | return output, masks 123 | 124 | def extand(self, t, extention, smax): 125 | print(extention) 126 | 127 | # self.BatchNorm0.extand(a) 128 | self.conv1.extand(0, extention[0]) 129 | self.BatchNorms0.append(torch.nn.BatchNorm2d(extention[0] + self.BatchNorms0[t].weight.shape[0]).cuda()) 130 | 131 | self.conv2.extand(extention[0], extention[1]) 132 | self.BatchNorms1.append(torch.nn.BatchNorm2d(extention[1] + self.BatchNorms1[t].weight.shape[0]).cuda()) 133 | 134 | self.conv3.extand(extention[1], extention[2]) 135 | self.BatchNorms2.append(torch.nn.BatchNorm2d(extention[2] + self.BatchNorms2[t].weight.shape[0]).cuda()) 136 | 137 | self.cap_conv1.append(self.conv1.weight.shape[1]) 138 | self.cap_conv2.append(self.conv2.weight.shape[1]) 139 | self.cap_conv3.append(self.conv3.weight.shape[1]) 140 | 141 | ec_1 = self.ec1.weight.data.clone() 142 | ec_2 = self.ec2.weight.data.clone() 143 | ec_3 = self.ec3.weight.data.clone() 144 | self.ec1 = torch.nn.Embedding(10, self.conv1.weight.shape[1]).cuda() 145 | self.ec2 = torch.nn.Embedding(10, self.conv2.weight.shape[1]).cuda() 146 | self.ec3 = torch.nn.Embedding(10, self.conv3.weight.shape[1]).cuda() 147 | self.ec1.weight.data.fill_(0) 148 | self.ec2.weight.data.fill_(0) 149 | self.ec3.weight.data.fill_(0) 150 | 151 | self.ec1.weight.data[:t + 1, :].fill_(-90) 152 | self.ec2.weight.data[:t + 1, :].fill_(-90) 153 | self.ec3.weight.data[:t + 1, :].fill_(-90) 154 | 155 | self.ec1.weight.data[:t + 1, :ec_1.shape[1]].copy_(ec_1[:t + 1, :]) 156 | self.ec2.weight.data[:t + 1, :ec_2.shape[1]].copy_(ec_2[:t + 1, :]) 157 | self.ec3.weight.data[:t + 1, :ec_3.shape[1]].copy_(ec_3[:t + 1, :]) 158 | 159 | return [self.conv1.weight.shape, self.conv2.weight.shape, self.conv3.weight.shape] 160 | 161 | def mask(self, t, s=1, test=False): 162 | gc1 = self.gate(s * self.ec1(t)) 163 | gc2 = self.gate(s * self.ec2(t)) 164 | gc3 = self.gate(s * self.ec3(t)) 165 | return [gc1, gc2, gc3] 166 | 167 | def get_view_for(self, n, masks): 168 | gc1, gc2, gc3 = masks 169 | if n == 'conv1.weight': 170 | return gc1.data.view(1, -1, 1, 1).expand_as(self.conv1.weight) 171 | elif n == 'conv1.bias': 172 | return 
gc1.data.view(-1) 173 | elif n == 'conv2.weight': 174 | post = gc2.data.view(1, -1, 1, 1).expand_as(self.conv2.weight) 175 | pre = gc1.data.view(-1, 1, 1, 1).expand_as(self.conv2.weight) 176 | return torch.min(post, pre) 177 | elif n == 'conv2.bias': 178 | return gc2.data.view(-1) 179 | elif n == 'conv3.weight': 180 | post = gc3.data.view(1, -1, 1, 1).expand_as(self.conv3.weight) 181 | pre = gc2.data.view(-1, 1, 1, 1).expand_as(self.conv3.weight) 182 | return torch.min(post, pre) 183 | elif n == 'conv3.bias': 184 | return gc3.data.view(-1) 185 | return None 186 | 187 | 188 | class netD(nn.Module): 189 | def __init__(self, ndf, nc, out_class=1): 190 | super(netD, self).__init__() 191 | self.nc = nc 192 | self.momentum = 0.1 193 | self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True) 194 | self.conv1 = nn.Conv2d(nc, ndf, 5, 2, 2, bias=False) 195 | self.conv2 = nn.Conv2d(ndf, ndf * 2, 5, 2, 2, bias=False) 196 | self.BatchNorm2 = BatchNorm2d(ndf * 2, momentum=self.momentum, track_running_stats=True) 197 | self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 5, 2, 2, bias=False) 198 | 199 | self.BatchNorm3 = BatchNorm2d(ndf * 4, momentum=self.momentum, track_running_stats=True) 200 | self.output_size = ndf * 4 * 4 * 4 201 | self.disc_linear = nn.Linear(self.output_size, 1) # .append(nn.Linear(ndf, 1)) 202 | self.aux_linear = nn.Linear(self.output_size, out_class) 203 | 204 | self.softmax = nn.Softmax(dim=1) 205 | self.sigmoid = nn.Sigmoid() 206 | self.ndf = ndf 207 | self.apply(weights_init) 208 | 209 | def forward(self, input): 210 | batch_size = input.size()[0] 211 | x = self.conv1(input) 212 | x = self.LeakyReLU(x) 213 | x = self.conv2(x) 214 | x = self.BatchNorm2(x) 215 | x = self.LeakyReLU(x) 216 | x = self.conv3(x) 217 | x = self.BatchNorm3(x) 218 | x = self.LeakyReLU(x) 219 | x = x.view(batch_size, -1) 220 | c = self.aux_linear(x) 221 | s = self.disc_linear(x) 222 | return s.view(-1), c -------------------------------------------------------------------------------- /networks/net_DGMw.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SAP SE 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
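# [Editor's note] A minimal, self-contained sketch (not from this repository)
# of the annealed sigmoid gating used by the mask() methods in this file: a
# per-task embedding e is squashed with sigmoid(s * e), and the temperature s
# is annealed from 1/smax up to smax over each epoch (as in
# hard-attention-to-the-task, HAT), so masks are soft during training and
# near-binary at s = smax.
def _demo_annealed_mask(num_units=8, smax=800.0, batch_idx=0, num_batches=100):
    import torch
    e = torch.zeros(num_units)            # embeddings start at 0 => gates at 0.5
    s = (smax - 1 / smax) * batch_idx / num_batches + 1 / smax
    return torch.sigmoid(s * e)           # per-unit gates in (0, 1)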
13 | import torch.nn as nn 14 | import torch 15 | import numpy as np 16 | from torch.nn.parameter import Parameter 17 | import math 18 | import torch.nn.functional as F 19 | from utils.utils import weights_init, weights_init_g 20 | 21 | 22 | class Plastic_ConvTranspose2d(nn.ConvTranspose2d): 23 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 24 | groups=1, bias=True, 25 | num_tasks=1, out_size=None, batch_size=64, smax=800): 26 | super(Plastic_ConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, 27 | output_padding, groups, bias, dilation) 28 | if bias: 29 | self.ec = torch.nn.Embedding(num_tasks, 30 | out_channels * ((in_channels * kernel_size * kernel_size) + 1)).cuda() 31 | else: 32 | self.ec = torch.nn.Embedding(num_tasks, out_channels * ((in_channels * kernel_size * kernel_size))).cuda() 33 | 34 | self.smax = smax 35 | self.ec.weight.data.fill_(0) 36 | self.prev_weight_shape = self.weight.shape 37 | 38 | def forward(self, inputx, mask, d_in, d_out, output_size=None): 39 | self.prev_weight_shape = self.weight.shape 40 | output_padding = self._output_padding(inputx, output_size, self.stride, self.padding, self.kernel_size) 41 | bias = None 42 | if mask is not None: 43 | if not self.bias is None: 44 | bias = self.bias[:d_out] * mask[:, 0].contiguous().view(-1) 45 | mask_ = mask[:, 1:] 46 | else: 47 | mask_ = mask[:, :] 48 | out = F.conv_transpose2d(inputx, self.weight[:d_in, :d_out, :, :] * mask_.contiguous().view( 49 | self.weight.data.shape)[:d_in, :d_out, :, :], 50 | bias, self.stride, self.padding, output_padding, groups=self.groups, 51 | dilation=self.dilation) 52 | else: 53 | if not self.bias is None: 54 | bias = self.bias[:d_out] 55 | out = F.conv_transpose2d(inputx, self.weight[:d_in, :d_out, :, :], bias, self.stride, self.padding, 56 | output_padding, groups=self.groups, dilation=self.dilation) 57 | return out 58 | 59 | def expand(self, input_channels, out_channels): 60 | w_old = self.weight.data.clone() 61 | if self.bias is not None: 62 | b_old = self.bias.data.clone() 63 | self.out_channels += out_channels 64 | self.in_channels += input_channels 65 | self.weight = Parameter( 66 | torch.Tensor(self.in_channels, self.out_channels // self.groups, *self.kernel_size).cuda()) 67 | if self.bias is not None: 68 | self.bias = Parameter(torch.Tensor(self.out_channels)) 69 | self.weight.data.fill_(0) 70 | self.apply(weights_init) 71 | self.weight.data[:w_old.shape[0]:, :w_old.shape[1], :, :].copy_(w_old) 72 | if self.bias is not None: 73 | self.bias.data[:b_old.shape[0]].copy_(b_old) 74 | return self.weight.shape 75 | 76 | def expand_embeddings(self, n_new_classes, mask_pre=None): 77 | ec = self.ec.weight.view([-1] + list(self.prev_weight_shape)).data.clone() 78 | new_dim = self.out_channels * ((self.in_channels * self.kernel_size[0] * self.kernel_size[1])) 79 | if self.bias is not None: 80 | new_dim = self.out_channels * ((self.in_channels * self.kernel_size[0] * self.kernel_size[1]) + 1) 81 | self.ec = torch.nn.Embedding(ec.shape[0] + n_new_classes, new_dim).cuda() 82 | self.ec.weight.data.fill_(0) 83 | 84 | if ec.shape[0] > 0: 85 | self.ec.weight.data[:ec.shape[0], :].fill_(-90) # for generating old samples do not use newly added parameters 86 | self.ec.weight.view([-1] + list(self.weight.shape)).data[:ec.shape[0], :ec.shape[1], :ec.shape[2], :, 87 | :].copy_(ec[:, :, :, :, :]) # but only the reserved once 88 | return self.prev_weight_shape 89 | 90 | 91 | class 
BatchNorm2d(torch.nn.BatchNorm2d): 92 | def forward(self, input): 93 | self._check_input_dim(input) 94 | if self.training: 95 | momentum = self.momentum 96 | else: 97 | momentum = 0. 98 | return F.batch_norm( 99 | input, self.running_mean, self.running_var, self.weight, self.bias, 100 | self.training or not self.track_running_stats, momentum, self.eps) 101 | 102 | 103 | class netG(nn.Module): 104 | def __init__(self, nz, ngf, nc, smax, n_classes=1): 105 | super(netG, self).__init__() 106 | 107 | self.gate = torch.nn.Sigmoid() 108 | 109 | self.nc = nc 110 | self.nz = nz 111 | self.ngf = ngf 112 | self.scalor = 1 113 | self.smax = smax 114 | self.ReLU = nn.ReLU(True) 115 | self.Tanh = nn.Tanh() 116 | 117 | self.conv1 = Plastic_ConvTranspose2d(nz, ngf * 4 * self.scalor, 4, 1, 0, bias=False, num_tasks=n_classes, 118 | smax=smax) 119 | # self.conv1 = Plastic_ConvTranspose2d(nz, ngf * 8*self.scalor, 3, 1, 0, bias=False, num_tasks=n_classes, smax=smax) 120 | 121 | self.cap_conv1 = [ngf * 4 * self.scalor] 122 | self.BatchNorms1 = torch.nn.ModuleList() 123 | self.BatchNorms1.append(torch.nn.BatchNorm2d(ngf * 4 * self.scalor).apply(weights_init)) 124 | 125 | self.conv2 = Plastic_ConvTranspose2d(ngf * 4 * self.scalor, ngf * 2 * self.scalor, 4, 2, 1, bias=False, 126 | num_tasks=n_classes, smax=smax) 127 | self.cap_conv2 = [ngf * 2 * self.scalor] 128 | self.BatchNorms2 = torch.nn.ModuleList() 129 | self.BatchNorms2.append(torch.nn.BatchNorm2d(ngf * 2 * self.scalor).apply(weights_init)) 130 | 131 | self.conv3 = Plastic_ConvTranspose2d(ngf * 2 * self.scalor, ngf * 1 * self.scalor, 4, 2, 1, bias=False, 132 | num_tasks=n_classes, smax=smax) 133 | self.cap_conv3 = [ngf * 1 * self.scalor] 134 | self.BatchNorms3 = torch.nn.ModuleList() 135 | self.BatchNorms3.append(torch.nn.BatchNorm2d(ngf * 1 * self.scalor).apply(weights_init)) 136 | 137 | self.apply(weights_init_g) 138 | self.last = self.last = torch.nn.ModuleList() 139 | 140 | def extand(self, extention): 141 | ws_0 = self.conv1.expand(0, math.ceil( 142 | extention[0] / (self.nz * self.conv1.kernel_size[0] * self.conv1.kernel_size[1]))) 143 | 144 | n_in_conv2 = self.conv1.weight.shape[1] 145 | n_out_conv2 = self.conv2.weight.shape[1] 146 | n_params_conv2 = n_in_conv2 * n_out_conv2 * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] 147 | n_add_out_conv2 = max(math.ceil((extention[1] - (n_params_conv2 - np.prod(self.conv2.weight.size()).item())) / ( 148 | self.conv1.weight.data.shape[1] * self.conv2.kernel_size[0] * self.conv2.kernel_size[1])), 0) 149 | 150 | ws_1 = self.conv2.expand( 151 | math.ceil(extention[0] / (self.nz * self.conv1.kernel_size[0] * self.conv1.kernel_size[1])), 152 | n_add_out_conv2) 153 | n_in_conv3 = self.conv2.weight.shape[1] 154 | n_out_conv3 = self.conv3.weight.shape[1] 155 | n_params_conv3 = n_in_conv3 * n_out_conv3 * self.conv3.kernel_size[0] * self.conv3.kernel_size[1] 156 | n_add_out_conv3 = max(math.ceil((extention[2] - (n_params_conv3 - np.prod(self.conv3.weight.size()).item())) / ( 157 | self.conv2.weight.data.shape[1] * self.conv3.kernel_size[0] * self.conv3.kernel_size[1])), 0) 158 | 159 | ws_2 = self.conv3.expand(n_add_out_conv2, n_add_out_conv3) 160 | self.cap_conv1.append(self.conv1.weight.shape[1]) 161 | self.cap_conv2.append(self.conv2.weight.shape[1]) 162 | self.cap_conv3.append(self.conv3.weight.shape[1]) 163 | self.BatchNorms1.append(torch.nn.BatchNorm2d(self.conv1.weight.shape[1]).cuda()) 164 | self.BatchNorms2.append(torch.nn.BatchNorm2d(self.conv2.weight.shape[1]).cuda()) 165 | 
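# [Editor's note] The n_add_out_conv2/3 arithmetic above converts a requested
# parameter budget (extention[i]) into a number of new output channels:
# widening the previous layer already added input channels to this layer, so
# that slack (n_params_conv* minus the current weight count) is subtracted
# from the budget before dividing by the cost of one new output channel
# (in_channels * kH * kW), rounding up and clamping at zero.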
self.BatchNorms3.append(torch.nn.BatchNorm2d(self.conv3.weight.shape[1]).cuda()) 166 | 167 | return [ws_0, ws_1, ws_2] 168 | 169 | def expand_embeddings(self, n_new_classes): 170 | prev_weight_shape_0 = self.conv1.expand_embeddings(n_new_classes) 171 | prev_weight_shape_1 = self.conv2.expand_embeddings(n_new_classes) 172 | prev_weight_shape_2 = self.conv3.expand_embeddings(n_new_classes) 173 | return [prev_weight_shape_0, prev_weight_shape_1, prev_weight_shape_2] 174 | 175 | def forward(self, input, t, lables=None, s=1, t_mix=None): 176 | task = torch.autograd.Variable(torch.LongTensor([t]).cuda()) 177 | masks = self.mask(task, None, s=s) 178 | 179 | gc1, gc2, gc3 = masks 180 | # print(input.shape) 181 | x = self.conv1(input, gc1, self.nz, self.cap_conv1[t]) 182 | x = self.BatchNorms1[t](x) 183 | x = self.ReLU(x) 184 | 185 | x = self.conv2(x, gc2, self.cap_conv1[t], self.cap_conv2[t]) 186 | x = self.BatchNorms2[t](x) 187 | x = self.ReLU(x) 188 | 189 | x = self.conv3(x, gc3, self.cap_conv2[t], self.cap_conv3[t]) 190 | x = self.BatchNorms3[t](x) 191 | x = self.ReLU(x) 192 | output = self.Tanh(self.last[t](x, None, self.cap_conv3[t], self.nc)) 193 | return output, masks 194 | 195 | def mask(self, t, labels, s=1, test=False): 196 | gc1 = self.gate(s * self.conv1.ec(t)) # .view(self.conv1.out_channels, ( 197 | # self.conv1.in_channels * self.conv1.kernel_size[0] * self.conv1.kernel_size[1]))# + 1)) 198 | gc2 = self.gate(s * self.conv2.ec(t)) # .view(self.conv2.out_channels, ( 199 | # self.conv2.in_channels * self.conv2.kernel_size[0] * self.conv2.kernel_size[1]))# + 1)) 200 | gc3 = self.gate(s * self.conv3.ec(t)) # .view(self.conv3.out_channels, ( 201 | # self.conv3.in_channels * self.conv3.kernel_size[0] * self.conv3.kernel_size[1]))# + 1)) 202 | # gc4 = self.gate(s * self.conv4.ec(t))#.view(self.conv4.out_channels, ( 203 | # #self.conv4.in_channels * self.conv4.kernel_size[0] * self.conv4.kernel_size[1]))# + 1)) 204 | return [gc1, gc2, gc3] # ,gc4] 205 | 206 | def get_total_mask(self, t, labels, s): 207 | task = torch.autograd.Variable(torch.LongTensor(labels.data.cpu().numpy()).cuda()) 208 | # t = torch.autograd.Variable(torch.LongTensor([t]).cuda()) 209 | masks = self.mask(task, None, s=s) 210 | m0 = torch.max(masks[0], 0)[0].view(1, -1) 211 | m1 = torch.max(masks[1], 0)[0].view(1, -1) 212 | m2 = torch.max(masks[2], 0)[0].view(1, -1) 213 | return [m0, m1, m2] # + m3 214 | 215 | def get_view_for(self, n, masks): 216 | gc1, gc2, gc3 = masks 217 | if n == 'conv1.weight': 218 | return gc1.data[:, :].contiguous().view(self.conv1.weight.shape) 219 | elif n == 'conv1.bias': 220 | return gc1.data[:, 0].contiguous().view(-1) 221 | 222 | elif n == 'conv2.weight': 223 | return gc2.data[:, :].contiguous().view(self.conv2.weight.shape) 224 | elif n == 'conv2.bias': 225 | return gc2.data[:, 0].contiguous().view(-1) 226 | 227 | elif n == 'conv3.weight': 228 | return gc3.data[:, :].contiguous().view(self.conv3.weight.shape) 229 | elif n == 'conv3.bias': 230 | return gc3.data[:, 0].contiguous().view(-1) 231 | return None 232 | 233 | 234 | class netD(nn.Module): 235 | def __init__(self, ndf, nc, out_class=1): 236 | super(netD, self).__init__() 237 | self.nc = nc 238 | self.momentum = 0.1 239 | self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True) 240 | self.conv1 = nn.Conv2d(nc, ndf, 5, 2, 2, bias=False) 241 | self.conv2 = nn.Conv2d(ndf, ndf * 2, 5, 2, 2, bias=False) 242 | self.BatchNorm2 = BatchNorm2d(ndf * 2, momentum=self.momentum, track_running_stats=True) 243 | # self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 
2, 1, bias=False) 244 | self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 5, 2, 2, bias=False) 245 | 246 | self.BatchNorm3 = BatchNorm2d(ndf * 4, momentum=self.momentum, track_running_stats=True) 247 | self.output_size = ndf * 4 * 4 * 4 248 | self.disc_linear = nn.Linear(self.output_size, 1) # .append(nn.Linear(ndf, 1)) 249 | self.aux_linear = nn.Linear(self.output_size, out_class) 250 | 251 | self.softmax = nn.Softmax(dim=1) 252 | self.sigmoid = nn.Sigmoid() 253 | self.ndf = ndf 254 | self.apply(weights_init) 255 | 256 | def forward(self, input): 257 | batch_size = input.size()[0] 258 | x = self.conv1(input) 259 | x = self.LeakyReLU(x) 260 | x = self.conv2(x) 261 | x = self.BatchNorm2(x) 262 | x = self.LeakyReLU(x) 263 | x = self.conv3(x) 264 | x = self.BatchNorm3(x) 265 | x = self.LeakyReLU(x) 266 | x = x.view(batch_size, -1) 267 | c = self.aux_linear(x) 268 | s = self.disc_linear(x) 269 | return s.view(-1), c 270 | -------------------------------------------------------------------------------- /networks/net_DGMw_imnet.py: -------------------------------------------------------------------------------- 1 | #Copyright 2019 SAP SE 2 | #Licensed under the Apache License, Version 2.0 (the "License"); 3 | #you may not use this file except in compliance with the License. 4 | #You may obtain a copy of the License at 5 | 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | #Unless required by applicable law or agreed to in writing, software 9 | #distributed under the License is distributed on an "AS IS" BASIS, 10 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | #See the License for the specific language governing permissions and 12 | #limitations under the License. 13 | import torch.nn as nn 14 | import torch 15 | import numpy as np 16 | from copy import copy 17 | import torch.nn.functional as F 18 | from networks.resnet import resnet18 19 | import math 20 | from torch.nn.parameter import Parameter 21 | from utils.utils import weights_init_g 22 | 23 | 24 | class Linear_extandable(nn.Linear): 25 | def __init__(self, in_features, out_features, num_tasks=1, bias=True, smax=1000, device=None): 26 | super(Linear_extandable, self).__init__(in_features, out_features, bias) 27 | self.in_features = in_features 28 | self.out_features = out_features 29 | self.smax = smax 30 | self.device = device 31 | # self.reset_parameters() 32 | self.ec = torch.nn.Embedding(1, (self.in_features) * self.out_features) 33 | self.ec_b = None 34 | if self.bias is not None: 35 | self.ec_b = torch.nn.Embedding(1, self.out_features) 36 | 37 | self.s = torch.nn.ParameterList() 38 | self.ec.weight.data.fill_(0)#6/self.smax) 39 | self.ec_past = torch.sparse.FloatTensor(10, self.out_features, self.in_features) 40 | 41 | self.prev_weight_shape = self.weight.shape 42 | 43 | def forward(self, inputx, mask, d_in, d_out, output_size=None): 44 | if mask is not None: 45 | bias = None 46 | if self.bias is not None: 47 | bias = self.bias[:d_out] * mask[1][:d_out] 48 | out = F.linear(inputx, self.weight[:d_out, :d_in] * mask[0][:d_out, :d_in], bias) 49 | else: 50 | out = F.linear(inputx, self.weight[:d_out, :d_in], self.bias[:d_out]) 51 | 52 | return out 53 | 54 | def extand(self, delta_in_features, delta_out_features): 55 | w_old = self.weight.data.clone() 56 | b_old = None 57 | if self.bias is not None: 58 | b_old = self.bias.data.clone() 59 | self.out_features += delta_out_features[0] 60 | self.out_features = int(16 * math.ceil(self.out_features / 16.)) 61 | self.in_features += delta_in_features 62 | 63 | 
self.weight = Parameter(torch.Tensor(self.out_features, self.in_features).cuda(self.device)) 64 | 65 | if self.bias is not None: 66 | self.bias = Parameter(torch.Tensor(self.out_features).cuda(self.device)) 67 | 68 | self.apply(weights_init_g) 69 | self.weight.data[:w_old.shape[0], :w_old.shape[1]].copy_(w_old) 70 | if self.bias is not None: 71 | self.bias.data[:b_old.shape[0]].copy_(b_old) 72 | 73 | del (w_old) 74 | del (b_old) 75 | torch.cuda.empty_cache() 76 | return self.weight.shape 77 | 78 | def expand_embeddings(self, n_new_classes, t, mask): 79 | # extand amd store the masks 80 | a = self.ec_past.to_dense().cpu() 81 | a[t] = mask[0] 82 | self.ec_past = torch.sparse.FloatTensor((a == 1).nonzero().t(), 83 | torch.ones((a == 1).nonzero().shape[0]), 84 | torch.Size([10, self.weight.shape[0], self.weight.shape[1]])).cuda(self.device) 85 | del (a) 86 | torch.cuda.empty_cache() 87 | self.ec = torch.nn.Embedding(1, self.weight.shape[0] * self.weight.shape[1]).cuda(self.device) 88 | self.ec.weight.data.fill_(6/self.smax) 89 | self.prev_weight_shape = self.weight.shape 90 | 91 | class Plastic_Conv2d(nn.Conv2d): 92 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 93 | num_tasks=1, smax=1000, device=None): 94 | super(Plastic_Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, 95 | bias) 96 | 97 | self.ec_bias = None 98 | self.device = device 99 | if bias: 100 | self.ec_bias = torch.nn.Embedding(num_tasks, out_channels) # .cuda() 101 | self.ec_bias.weight.data.fill_(0) 102 | 103 | self.ec = torch.nn.Embedding(1, out_channels * ((in_channels * kernel_size * kernel_size))) # .cuda() 104 | self.smax = smax 105 | self.ec.weight.data.fill_(0) # 6/self.smax) 106 | self.ec_past = torch.sparse.FloatTensor(10, self.weight.shape[0], self.weight.shape[1], self.weight.shape[2], 107 | self.weight.shape[3]) 108 | self.prev_weight_shape = self.weight.shape 109 | 110 | def forward(self, inputx, mask, d_in, d_out): 111 | # self.prev_weight_shape=self.weight.shape 112 | bias = None 113 | if mask is not None: 114 | mask_ = mask[0] 115 | if self.bias is not None: 116 | mask_b = mask[1] 117 | bias = self.bias[:d_out] * mask_b[:d_out] 118 | out = F.conv2d(inputx, self.weight[:d_out, :d_in, :, :] * mask_[:d_out, :d_in, :, :], bias, self.stride, 119 | self.padding, groups=self.groups, dilation=self.dilation) 120 | 121 | else: 122 | if self.bias is not None: 123 | bias = self.bias[:d_out] 124 | out = F.conv2d(inputx, self.weight[:d_out, :d_in, :, :], bias, self.stride, self.padding, 125 | groups=self.groups, dilation=self.dilation) 126 | return out 127 | 128 | def expand(self, input_channels, out_channels): 129 | b_old = None 130 | w_old = self.weight.data.clone() 131 | if self.bias is not None: 132 | b_old = self.bias.data.clone() 133 | self.out_channels += out_channels 134 | self.in_channels += input_channels 135 | self.weight = Parameter( 136 | torch.Tensor(self.out_channels, (self.in_channels // self.groups), *self.kernel_size).cuda(self.device)) 137 | if self.bias is not None: 138 | self.bias = Parameter(torch.Tensor(self.out_channels)).cuda(self.device) 139 | self.apply(weights_init_g) 140 | self.weight.data[:w_old.shape[0]:, :w_old.shape[1], :, :].copy_(w_old) 141 | if self.bias is not None: 142 | self.bias.data[:b_old.shape[0]].copy_(b_old) 143 | del(w_old) 144 | torch.cuda.empty_cache() 145 | 146 | 147 | return self.weight.shape 148 | 149 | def expand_embeddings(self, n_new_classes, t, mask): 150 | a 
= self.ec_past.to_dense().cpu() 151 | a[t] = mask[0] 152 | if (a == 1).nonzero().shape[0] == 0: 153 | self.ec_past = torch.sparse.FloatTensor(10, self.weight.shape[0], self.weight.shape[1], 154 | self.weight.shape[2], self.weight.shape[3]).cuda(self.device) 155 | else: 156 | self.ec_past = torch.sparse.FloatTensor((a == 1).nonzero().t(), 157 | torch.ones((a == 1).nonzero().shape[0]), 158 | torch.Size([10, self.weight.shape[0], self.weight.shape[1], 159 | self.weight.shape[2], self.weight.shape[3]])).cuda( 160 | self.device) 161 | del (a) 162 | torch.cuda.empty_cache() 163 | self.ec = torch.nn.Embedding(1, self.out_channels * ( 164 | (self.in_channels * self.kernel_size[0] * self.kernel_size[0]))).cuda(self.device) 165 | self.ec.weight.data.fill_(6/self.smax) 166 | self.prev_weight_shape = self.weight.shape 167 | 168 | def avg_pool2d(x): 169 | '''Twice differentiable implementation of 2x2 average pooling.''' 170 | return (x[:, :, ::2, ::2] + x[:, :, 1::2, ::2] + x[:, :, ::2, 1::2] + x[:, :, 1::2, 1::2]) / 4 171 | 172 | class GeneratorBlock(nn.Module): 173 | '''ResNet-style block for the generator model.''' 174 | def __init__(self, in_chans, out_chans, smax, upsample=False, device="cuda"): 175 | super(GeneratorBlock, self).__init__() 176 | 177 | self.gate = torch.nn.Sigmoid() 178 | self.in_chans = in_chans 179 | self.out_chans = out_chans 180 | self.device = device 181 | 182 | self.upsample = upsample 183 | self.shortcut_conv = Plastic_Conv2d(self.in_chans, self.out_chans, kernel_size=1, bias=False, 184 | device=self.device) 185 | if self.in_chans != self.out_chans: 186 | self.cap_shortcut = [self.out_chans] 187 | else: 188 | self.cap_shortcut = [None] 189 | self.BatchNorm1s = torch.nn.ModuleList() 190 | self.BatchNorm1s.append(torch.nn.BatchNorm2d(in_chans)) 191 | self.conv1 = Plastic_Conv2d(in_chans, in_chans, kernel_size=3, padding=1, bias=False, device=self.device) 192 | self.cap_conv1 = [in_chans] 193 | self.BatchNorm2s = torch.nn.ModuleList() 194 | self.BatchNorm2s.append(torch.nn.BatchNorm2d(in_chans)) 195 | self.conv2 = Plastic_Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False, device=self.device) 196 | self.cap_conv2 = [out_chans] 197 | 198 | def extand(self, t, inp_dim, c1, c2, c_s, smax): 199 | print("inp_dim", inp_dim) 200 | 201 | n_out_conv1 = self.conv1.weight.data.shape[0] 202 | n_params_conv1 = inp_dim * n_out_conv1 * self.conv1.kernel_size[0] * self.conv1.kernel_size[1] # n_params after adding input channels 203 | n_add_out_conv1 = max(math.ceil((c1[0] - (n_params_conv1 - np.prod(self.conv1.weight.size()).item())) / 204 | (self.conv1.weight.data.shape[1] * self.conv1.kernel_size[0] * self.conv1.kernel_size[1])), 0) 205 | delta_in_conv_1 = inp_dim - self.in_chans 206 | 207 | ws_0 = self.conv1.expand(delta_in_conv_1, n_add_out_conv1) 208 | 209 | n_in_conv2 = self.conv1.weight.shape[0] 210 | n_out_conv2 = self.conv2.weight.data.shape[0] 211 | n_params_conv2 = n_in_conv2 * n_out_conv2 * self.conv2.kernel_size[0] * self.conv2.kernel_size[1] # n_params after adding input channels 212 | n_add_out_conv2 = max(math.ceil((c2[0] - (n_params_conv2 - np.prod(self.conv2.weight.size()).item())) / 213 | (self.conv2.weight.data.shape[1] * self.conv2.kernel_size[0] * 214 | self.conv2.kernel_size[1])), 0) 215 | 216 | ws_1 = self.conv2.expand(n_add_out_conv1, n_add_out_conv2) 217 | self.BatchNorm1s.append(torch.nn.BatchNorm2d(self.conv1.in_channels).cuda(self.device)) 218 | delta_conv2 = 0 219 | self.in_chans += delta_in_conv_1 220 | self.out_chans += n_add_out_conv2 221 | ws_s = 
self.shortcut_conv.weight.shape
222 | 
223 |         if self.in_chans != self.out_chans:
224 |             n_in_conv_sc = inp_dim
225 |             n_out_conv_sc = self.shortcut_conv.weight.data.shape[0]
226 |             n_params_conv_sc = n_in_conv_sc * n_out_conv_sc * self.shortcut_conv.weight.shape[0] * self.shortcut_conv.weight.shape[1]  # n_params after adding input channels
227 |             n_add_out_conv_sc = max(math.ceil((c_s[0] - (n_params_conv_sc - np.prod(self.shortcut_conv.weight.size()).item())) /
228 |                                     (self.shortcut_conv.weight.data.shape[1] * self.shortcut_conv.kernel_size[0] *
229 |                                      self.shortcut_conv.kernel_size[1])), 0)
230 |             ws_s = self.shortcut_conv.expand((self.conv1.weight.shape[1] - self.shortcut_conv.weight.shape[1]), n_add_out_conv_sc)
231 |             # keep the shortcut and conv2 output widths in sync after growing
232 |             if self.shortcut_conv.weight.shape[0] > self.conv2.weight.shape[0]:
233 |                 delta_conv2 = self.shortcut_conv.weight.shape[0] - self.conv2.weight.shape[0]
234 |                 _ = self.conv2.expand(0, delta_conv2)
235 | 
236 |             else:
237 |                 self.shortcut_conv.expand(0, self.conv2.weight.shape[0] - self.shortcut_conv.weight.shape[0])
238 |             self.cap_shortcut.append(self.shortcut_conv.out_channels)
239 | 
240 |         else:
241 |             self.cap_shortcut.append(None)
242 | 
243 |         self.out_chans += delta_conv2
244 |         self.BatchNorm2s.append(torch.nn.BatchNorm2d(self.conv2.in_channels).cuda(self.device))
245 |         self.cap_conv1.append(self.conv1.weight.shape[0])
246 |         self.cap_conv2.append(self.conv2.weight.shape[0])
247 |         torch.cuda.empty_cache()
248 |         return ws_1[0] + delta_conv2
249 | 
250 |     def expand_embeddings(self, n_new_classes, t, mask):
251 |         self.conv1.expand_embeddings(n_new_classes, t, mask[0])
252 |         self.conv2.expand_embeddings(n_new_classes, t, mask[1])
253 |         self.shortcut_conv.expand_embeddings(n_new_classes, t, mask[2])
254 | 
255 |         prev_weight_shape_0 = copy(self.conv1.prev_weight_shape)
256 |         prev_weight_shape_1 = copy(self.conv2.prev_weight_shape)
257 |         prev_weight_shape_s = copy(self.shortcut_conv.prev_weight_shape)
258 |         self.conv1.prev_weight_shape = self.conv1.weight.shape
259 |         self.conv2.prev_weight_shape = self.conv2.weight.shape
260 |         self.shortcut_conv.prev_weight_shape = self.shortcut_conv.weight.shape
261 |         return [prev_weight_shape_0, prev_weight_shape_1, prev_weight_shape_s]
262 | 
263 |     def forward(self, input_, cap_prev, task, s=1, past_generation=False):
264 |         x = input_
265 |         t = task
266 |         if not past_generation:
267 |             masks = self.mask(task, s=s)  # soft, annealed masks while training the current task
268 |         else:
269 |             masks = self.eval_masks(task)  # frozen binary masks for replaying past tasks
270 |         gc1, gc2, gc_s = masks
271 |         if self.upsample:
272 |             shortcut = nn.functional.upsample(x, scale_factor=2, mode='nearest')
273 |         else:
274 |             shortcut = x
275 |         if self.cap_shortcut[t] is not None:
276 |             shortcut = self.shortcut_conv(shortcut, gc_s, cap_prev, self.cap_shortcut[t])
277 |         x = self.BatchNorm1s[t](x)  # (x, cap_prev)
278 |         x = nn.functional.relu(x, inplace=True)
279 |         if self.upsample:
280 |             x = nn.functional.upsample(x, scale_factor=2, mode='nearest')
281 |         x = self.conv1(x, gc1, cap_prev, self.cap_conv1[t])
282 |         x = self.BatchNorm2s[t](x)  # (x, self.cap_conv1[t])
283 |         x = nn.functional.relu(x, inplace=True)
284 |         x = self.conv2(x, gc2, self.cap_conv1[t], self.cap_conv2[t])
285 | 
286 |         return x + shortcut, self.cap_conv2, masks
287 | 
288 |     def eval_masks(self, task):
289 |         gc1 = self.conv1.ec_past.to_dense()[task].squeeze(0)
290 |         gc2 = self.conv2.ec_past.to_dense()[task].squeeze(0)
291 |         gcs = self.shortcut_conv.ec_past.to_dense()[task].squeeze(0)
292 |         return [[gc1, None], [gc2, None], [gcs, None]]
293 | 
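    # Editor's note on the gating below: each plastic layer owns an embedding `ec`,
    # and the per-weight mask is sigmoid(s * ec). With small s the gates stay soft
    # and trainable; as s approaches smax they saturate towards a binary mask.
    # A minimal illustration of this annealing, assuming plain PyTorch (the numbers
    # here are made up and not taken from the training code):
    #
    #     import torch
    #     e = torch.tensor([-0.01, 0.0, 0.01])      # learned embedding entries
    #     torch.sigmoid(1.0 * e)     # ~[0.4975, 0.5000, 0.5025] -> soft gates
    #     torch.sigmoid(1000.0 * e)  # ~[0.0000, 0.5000, 1.0000] -> near-binary gates
    #
    # eval_masks() above skips the sigmoid and instead replays the binarized masks
    # stored in ec_past when generating samples for previously learned tasks.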
294 |     def mask(self, t, s=1, test=False):
295 |         t = torch.autograd.Variable(torch.LongTensor([0]).cuda(self.device))  # a single shared embedding row; per-task masks live in ec_past
296 |         gc1 = self.gate(s * self.conv1.ec(t)).contiguous().view(self.conv1.weight.size())
297 |         gc1_b = None
298 |         if self.conv1.ec_bias is not None:
299 |             gc1_b = self.gate(s * self.conv1.ec_bias(t)).view(-1)
300 | 
301 |         gc2 = self.gate(s * self.conv2.ec(t)).view(self.conv2.weight.size())
302 |         gc2_b = None
303 |         if self.conv2.ec_bias is not None:
304 |             gc2_b = self.gate(s * self.conv2.ec_bias(t)).view(-1)
305 | 
306 |         gc_s = self.gate(s * self.shortcut_conv.ec(t)).view(self.shortcut_conv.weight.size())
307 |         gcs_b = None
308 |         if self.shortcut_conv.ec_bias is not None:
309 |             gcs_b = self.gate(s * self.shortcut_conv.ec_bias(t)).view(-1)
310 |         return [[gc1, gc1_b], [gc2, gc2_b], [gc_s, gcs_b]]
311 | 
312 |     def get_view_for(self, n, masks):
313 |         gc1, gc2, gc_s = masks
314 |         if n.endswith('conv1.weight'):
315 |             return gc1[0][:, :].contiguous().view(self.conv1.weight.shape)
316 |         elif n.endswith('conv1.bias'):
317 |             gc1 = gc1[1][:, :].contiguous().view(self.conv1.bias.shape)
318 |             return gc1.data.view(-1)
319 |         elif n.endswith('conv2.weight'):
320 |             return gc2[0][:, :].contiguous().view(self.conv2.weight.shape)
321 |         elif n.endswith('conv2.bias'):
322 |             gc2 = gc2[1][:, :].contiguous().view(self.conv2.bias.shape)
323 |             return gc2.data.view(-1)
324 |         elif n.endswith('shortcut_conv.weight'):
325 |             return gc_s[0][:, :].contiguous().view(self.shortcut_conv.weight.shape)
326 |         elif n.endswith('shortcut_conv.bias'):
327 |             gc_s = gc_s[1][:, :].contiguous().view(self.shortcut_conv.bias.shape)
328 |             return gc_s.data.view(-1)
329 |         return None
330 | 
331 | 
332 | class netG(nn.Module):
333 |     def __init__(self, nz, ngf, nc, smax, device):
334 |         super(netG, self).__init__()
335 |         self.nz = nz
336 |         self.gate = torch.nn.Sigmoid()
337 |         self.nc = nc
338 |         self.ngf = ngf
339 |         self.device = device
340 |         self.tanh = nn.Tanh()
341 |         self.smax = smax
342 |         self.scalor = 1
343 |         self.feats = ngf
344 |         self.fc1 = Linear_extandable(self.nz, 4 * 4 * self.feats, bias=False, device=self.device)
345 |         self.cap_fc0 = [4 * 4 * self.feats]
346 |         self.shape_fc_1_out = [4 * 4 * self.feats]
347 |         self.block1 = GeneratorBlock(self.feats, self.feats, smax, upsample=True,
348 |                                      device=self.device)
349 |         self.block2 = GeneratorBlock(self.feats, self.feats, smax, upsample=True,
350 |                                      device=self.device)
351 |         self.block3 = GeneratorBlock(self.feats, self.feats, smax, upsample=True,
352 |                                      device=self.device)
353 |         self.output_bns = torch.nn.ModuleList()
354 |         self.output_bns.append(torch.nn.BatchNorm2d(self.scalor * self.feats))
355 |         self.apply(weights_init_g)
356 |         self.last = torch.nn.ModuleList()  # one output head per class, filled by the approach code
357 |         self.efc1 = torch.nn.Embedding(10, self.scalor * 4 * 4 * self.feats)
358 |         self.efc1.weight.data.fill_(6 / smax)
359 | 
360 |     def extand(self, t, extention, smax):
361 |         print(extention)
362 |         ws_fc = self.fc1.extand(0, [math.ceil(extention[0][0] / (self.nz)), extention[0][1]])
363 |         self.cap_fc0.append(self.fc1.weight.shape[0])
364 |         a = int(math.ceil(self.fc1.weight.shape[0] / 16.))  # desired output is n_channels x 4 x 4, so the fc1 output must be divisible by 16
365 |         #print("a", a - self.block1.BatchNorm1s[t].weight.shape[0])
366 |         out_dim = self.block1.extand(t, a, extention[1], extention[2], extention[3], smax)
367 |         out_dim = self.block2.extand(t, out_dim, extention[4],
368 |                                      extention[5], extention[6], smax)
369 |         out_dim = self.block3.extand(t, out_dim, extention[7],
370 |                                      extention[8], extention[9], smax)
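        # Editor's note: each `extention[i]` is read here as a per-layer parameter
        # budget handed in by the approach code (an assumption on my part). A block
        # turns the budget into new output channels via
        # ceil(budget / (in_channels * kH * kW)); for example, a 3x3 conv with 64
        # input channels and a budget of 10000 parameters grows by
        # ceil(10000 / (64 * 3 * 3)) = ceil(10000 / 576) = 18 channels. `out_dim`
        # threads each block's grown width into the next block so layer shapes match.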
371 |         # self.output_bn.extand(extention[6])
372 |         self.output_bns.append(torch.nn.BatchNorm2d(out_dim).cuda(self.device))
373 |         ws_0 = [self.block1.conv1.weight.shape, self.block1.conv2.weight.shape, self.block1.shortcut_conv.weight.shape]
374 |         ws_1 = [self.block2.conv1.weight.shape, self.block2.conv2.weight.shape, self.block2.shortcut_conv.weight.shape]
375 |         ws_2 = [self.block3.conv1.weight.shape, self.block3.conv2.weight.shape, self.block3.shortcut_conv.weight.shape]
376 |         return [ws_fc] + ws_0 + ws_1 + ws_2
377 | 
378 |     def total_size_n_params(self):
379 |         size = 0
380 |         size += np.prod(self.efc1.weight.size())
381 |         size += np.prod(self.block1.conv1.weight.size()) + np.prod(self.block1.conv2.weight.size()) + \
382 |                 np.prod(self.block1.shortcut_conv.weight.size())
383 |         size += np.prod(self.block2.conv1.weight.size()) + np.prod(self.block2.conv2.weight.size()) + \
384 |                 np.prod(self.block2.shortcut_conv.weight.size())
385 |         size += np.prod(self.block3.conv1.weight.size()) + np.prod(self.block3.conv2.weight.size()) + \
386 |                 np.prod(self.block3.shortcut_conv.weight.size())
387 |         return size
388 | 
389 |     def total_size(self):
390 |         size = 0
391 |         size += self.efc1.weight.shape[0]
392 |         size += self.block1.conv1.weight.shape[0] + self.block1.conv2.weight.shape[0] + \
393 |                 self.block1.shortcut_conv.weight.shape[0]
394 |         size += self.block2.conv1.weight.shape[0] + self.block2.conv2.weight.shape[0] + \
395 |                 self.block2.shortcut_conv.weight.shape[0]
396 |         size += self.block3.conv1.weight.shape[0] + self.block3.conv2.weight.shape[0] + \
397 |                 self.block3.shortcut_conv.weight.shape[0]
398 |         return size
399 | 
400 | 
401 |     def expand_embeddings(self, n_new_classes, t, mask):
402 |         self.fc1.expand_embeddings(n_new_classes, t, mask[0])
403 |         prev_weight_shape_fc = [copy(self.fc1.prev_weight_shape)]
404 |         self.fc1.prev_weight_shape = self.fc1.weight.shape
405 | 
406 |         # self.fc1.expand_embeddings(n_new_classes)
407 |         prev_weight_shape_0 = self.block1.expand_embeddings(n_new_classes, t, mask[1:4])
408 |         # self.block1.expand_embeddings(n_new_classes)
409 |         prev_weight_shape_1 = self.block2.expand_embeddings(n_new_classes, t, mask[4:7])
410 |         # self.block2.expand_embeddings(n_new_classes)
411 |         prev_weight_shape_2 = self.block3.expand_embeddings(n_new_classes, t, mask[7:10])
412 |         return [prev_weight_shape_fc, prev_weight_shape_0, prev_weight_shape_1, prev_weight_shape_2]
413 | 
414 |     def forward(self, input_, task=0, s=1, past_generation=False, lables=None):
415 |         # task = torch.autograd.Variable(torch.LongTensor([t]).cuda())
416 |         if not past_generation:
417 |             gfc1 = self.mask(task, s=s)
418 |         else:
419 |             gfc1 = self.eval_masks(task)
420 |         t = task
421 |         x = input_
422 |         x = self.fc1(x.view(-1, self.nz), gfc1, self.nz, self.cap_fc0[t])
423 |         x = x.view(-1, int(x.shape[1] / 16.), 4, 4)  # reshape the fc output to (N, C, 4, 4) feature maps
424 |         x, prev_cap, masks1 = self.block1(x, x.shape[1], t, s, past_generation=past_generation)
425 |         x, prev_cap, masks2 = self.block2(x, prev_cap[t], t, s, past_generation=past_generation)
426 |         x, prev_cap, masks3 = self.block3(x, prev_cap[t], t, s, past_generation=past_generation)
427 |         x = self.output_bns[t](x)
428 |         x = nn.functional.relu(x, inplace=False)
429 |         output = torch.stack(  # route each sample through the output head of its own class label
430 |             [self.tanh(self.last[c](x[i].unsqueeze(0), None, prev_cap[t], self.nc)) for i, c in
431 |              enumerate(lables)]).squeeze(1)
432 | 
433 |         masks = [gfc1] + masks1 + masks2 + masks3
434 | 
435 |         return output, masks
436 | 
437 |     def eval_masks(self, task):
438 |         gc1 = self.fc1.ec_past.to_dense()[task].squeeze(0)
439 |         return [gc1, None]
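    # Editor's usage sketch (not from the original repo): generating one batch once
    # the model is built. `nz=100`, `ngf=64` and the label values are illustrative,
    # and `self.last` must already hold one plastic output head per class, which the
    # approach code is responsible for:
    #
    #     G = netG(nz=100, ngf=64, nc=3, smax=1000, device="cuda")
    #     z = torch.randn(8, 100).cuda()
    #     labels = torch.zeros(8, dtype=torch.long)           # all samples of class 0
    #     fake, masks = G(z, task=0, s=1000, lables=labels)   # NB: the keyword is spelled 'lables'
    #
    # With past_generation=True the forward pass replays the frozen binary masks in
    # ec_past instead of the current soft masks, which is how old classes are rehearsed.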
440 | 
441 |     def mask(self, t, s=1, test=False):
442 |         t = torch.autograd.Variable(torch.LongTensor([0]).cuda(self.device))
443 |         gf0 = self.gate(s * self.fc1.ec(t)).view(self.fc1.weight.shape)
444 |         gf0_b = None
445 |         if self.fc1.ec_b is not None:
446 |             gf0_b = self.gate(s * self.fc1.ec_b(t)).view(-1)
447 |         return [gf0, gf0_b]
448 | 
449 |     def get_total_mask(self, t, s):
450 |         m0 = self.mask(t, s)
451 |         m1 = self.block1.mask(t, s)
452 |         m2 = self.block2.mask(t, s)
453 |         m3 = self.block3.mask(t, s)
454 | 
455 |         return [m0] + m1 + m2 + m3
456 | 
457 |     def get_total_mask_eval(self, t):
458 |         m0 = self.eval_masks(t)
459 |         m1 = self.block1.eval_masks(t)
460 |         m2 = self.block2.eval_masks(t)
461 |         m3 = self.block3.eval_masks(t)
462 | 
463 |         return [m0] + m1 + m2 + m3
464 | 
465 |     def get_view_for(self, n, masks):
466 |         gfc1 = masks
467 |         if n == 'fc1.weight':
468 |             return gfc1[0].expand_as(self.fc1.weight)
469 |         elif n == 'fc1.bias':
470 |             return gfc1[1].data[:, :].view(-1)
471 | 
472 |         return None
473 | 
474 | class DiscriminatorBlock(nn.Module):
475 |     '''ResNet-style block for the discriminator model.'''
476 |     def __init__(self, in_chans, out_chans, downsample=False, first=False):
477 |         super(DiscriminatorBlock, self).__init__()
478 | 
479 |         self.downsample = downsample
480 |         self.first = first
481 | 
482 |         if in_chans != out_chans:
483 |             self.shortcut_conv = nn.Conv2d(in_chans, out_chans, kernel_size=1)
484 |         else:
485 |             self.shortcut_conv = None
486 |         self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1)
487 |         self.bn = nn.BatchNorm2d(out_chans)
488 |         self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1)
489 | 
490 |     def forward(self, *inputs):
491 |         x = inputs[0]
492 | 
493 |         if self.downsample:
494 |             shortcut = avg_pool2d(x)
495 |         else:
496 |             shortcut = x
497 | 
498 |         if self.shortcut_conv is not None:
499 |             shortcut = self.shortcut_conv(shortcut)
500 | 
501 |         if not self.first:
502 |             x = nn.functional.relu(x, inplace=False)
503 |         x = self.conv1(x)
504 |         x = nn.functional.relu(x, inplace=False)
505 |         x = self.bn(x)  # normalize before the second convolution
506 |         x = self.conv2(x)
507 |         if self.downsample:
508 |             x = avg_pool2d(x)
509 | 
510 |         return x + shortcut
511 | 
512 | class netD(nn.Module):
513 |     def __init__(self, feature_size, n_classes, device):
514 |         # Network architecture: ResNet-18 features with a real/fake head and an auxiliary class head
515 |         super(netD, self).__init__()
516 |         self.device = device
517 |         self.feats = feature_size
518 |         self.feature_extractor = resnet18()
519 |         self.feature_extractor.fc = \
520 |             nn.Linear(self.feature_extractor.fc.in_features, feature_size)
521 |         self.bn = nn.BatchNorm1d(feature_size, momentum=0.01)
522 |         self.ReLU = nn.ReLU()
523 |         self.aux_linear = nn.Linear(feature_size, n_classes, bias=False)
524 |         self.disc_linear = nn.Linear(feature_size, 1)
525 |         self.n_classes = n_classes
526 |         self.softmax = nn.Softmax(dim=1)
527 |         self.sigmoid = nn.Sigmoid()
528 | 
529 |     def forward(self, x):
530 |         x = self.feature_extractor(x)
531 |         x = self.bn(x)
532 |         x = self.ReLU(x)
533 |         c = self.aux_linear(x)
534 |         s = self.disc_linear(x)
535 |         return s.view(-1), c
536 | 
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet18/34/50/101/152 in Pytorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from torch.autograd import Variable
7 | 
8 | 
9 | def conv3x3(in_planes, out_planes, stride=1):
10 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1, shortcut=None): 17 | super(BasicBlock, self).__init__() 18 | self.layers = nn.Sequential( 19 | conv3x3(in_planes, planes, stride), 20 | nn.BatchNorm2d(planes), 21 | nn.ReLU(True), 22 | conv3x3(planes, planes), 23 | nn.BatchNorm2d(planes), 24 | ) 25 | self.shortcut = shortcut 26 | 27 | def forward(self, x): 28 | residual = x 29 | y = self.layers(x) 30 | if self.shortcut: 31 | residual = self.shortcut(x) 32 | y += residual 33 | y = F.relu(y) 34 | return y 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1, shortcut=None): 41 | super(Bottleneck, self).__init__() 42 | self.layers = nn.Sequential( 43 | nn.Conv2d(in_planes, planes, kernel_size=1, bias=False), 44 | nn.BatchNorm2d(planes), 45 | nn.ReLU(True), 46 | nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False), 47 | nn.BatchNorm2d(planes), 48 | nn.ReLU(True), 49 | nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False), 50 | nn.BatchNorm2d(planes * 4), 51 | ) 52 | self.shortcut = shortcut 53 | 54 | def forward(self, x): 55 | residual = x 56 | y = self.layers(x) 57 | if self.shortcut: 58 | residual = self.shortcut(x) 59 | y += residual 60 | y = F.relu(y) 61 | return y 62 | 63 | 64 | class ResNet(nn.Module): 65 | def __init__(self, block, nblocks, num_classes=10): 66 | super(ResNet, self).__init__() 67 | self.in_planes = 64 68 | self.pre_layers = nn.Sequential( 69 | conv3x3(3,64), 70 | nn.BatchNorm2d(64), 71 | nn.ReLU(True), 72 | ) 73 | self.layer1 = self._make_layer(block, 64, nblocks[0]) 74 | self.layer2 = self._make_layer(block, 128, nblocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, nblocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, nblocks[3], stride=2) 77 | self.avgpool = nn.AvgPool2d(4) 78 | self.fc = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, nblocks, stride=1): 81 | shortcut = None 82 | if stride != 1 or self.in_planes != planes * block.expansion: 83 | shortcut = nn.Sequential( 84 | nn.Conv2d(self.in_planes, planes * block.expansion, 85 | kernel_size=1, stride=stride, bias=False), 86 | nn.BatchNorm2d(planes * block.expansion), 87 | ) 88 | layers = [] 89 | layers.append(block(self.in_planes, planes, stride, shortcut)) 90 | self.in_planes = planes * block.expansion 91 | for i in range(1, nblocks): 92 | layers.append(block(self.in_planes, planes)) 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.pre_layers(x) 97 | x = self.layer1(x) 98 | x = self.layer2(x) 99 | x = self.layer3(x) 100 | x = self.layer4(x) 101 | x = self.avgpool(x) 102 | x = x.view(x.size(0), -1) 103 | #print(x.shape) 104 | x = self.fc(x) 105 | return x 106 | 107 | 108 | def resnet18(): 109 | return ResNet(BasicBlock, [2,2,2,2]) 110 | 111 | def resnet34(): 112 | return ResNet(BasicBlock, [3,4,6,3]) 113 | 114 | def resnet50(): 115 | return ResNet(Bottleneck, [3,4,6,3]) 116 | 117 | def resnet101(): 118 | return ResNet(Bottleneck, [3,4,23,3]) 119 | 120 | def resnet152(): 121 | return ResNet(Bottleneck, [3,8,36,3]) 122 | 123 | # net = ResNet(BasicBlock, [2,2,2,2]) 124 | # x = torch.randn(1,3,32,32) 125 | # y = net(Variable(x)) 126 | # print(y.size()) -------------------------------------------------------------------------------- /outputs/.gitignore: 
-------------------------------------------------------------------------------- 1 | [^.]* -------------------------------------------------------------------------------- /requierements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.6.1 2 | appdirs==1.4.3 3 | asn1crypto==0.24.0 4 | astor==0.7.1 5 | attrs==18.1.0 6 | Automat==0.3.0 7 | autovizwidget==0.12.5 8 | backcall==0.1.0 9 | biwrap==0.1.6 10 | bleach==2.1.3 11 | bokeh==0.13.0 12 | boto3==1.7.60 13 | botocore==1.10.60 14 | certifi==2018.11.29 15 | cffi==1.11.5 16 | characteristic==14.3.0 17 | chardet==3.0.4 18 | constantly==15.1.0 19 | cryptography==2.4.2 20 | cycler==0.10.0 21 | Cython==0.28.4 22 | decorator==4.3.0 23 | docutils==0.14 24 | easydict==1.9 25 | entrypoints==0.2.3 26 | environment-kernels==1.1.1 27 | future==0.17.1 28 | gast==0.2.0 29 | grpcio==1.16.0 30 | h5py==2.8.0 31 | hdijupyterutils==0.12.5 32 | html5lib==1.0.1 33 | hyperlink==17.3.1 34 | hyperopt==0.1.2 35 | idna==2.7 36 | incremental==17.5.0 37 | ipykernel==4.8.2 38 | ipython==6.4.0 39 | ipython-genutils==0.2.0 40 | ipywidgets==7.3.0 41 | jedi==0.12.1 42 | Jinja2==2.10 43 | jmespath==0.9.3 44 | jsonschema==2.6.0 45 | jupyter-client==5.2.3 46 | jupyter-core==4.4.0 47 | Keras==2.2.4 48 | Keras-Applications==1.0.6 49 | Keras-Preprocessing==1.0.5 50 | kiwisolver==1.0.1 51 | Markdown==3.0.1 52 | MarkupSafe==1.0 53 | matplotlib==2.2.2 54 | mistune==0.8.3 55 | mkl-fft==1.0.2 56 | mkl-random==1.0.1 57 | nb-conda==2.2.1 58 | nb-conda-kernels==2.1.1 59 | nbconvert==5.3.1 60 | nbformat==4.4.0 61 | networkx==2.3 62 | notebook==5.6.0 63 | numpy==1.14.5 64 | olefile==0.45.1 65 | onnx==1.2.1 66 | packaging==17.1 67 | pandas==0.22.0 68 | pandocfilters==1.4.2 69 | parso==0.3.0 70 | pexpect==4.6.0 71 | pickleshare==0.7.4 72 | Pillow==5.1.0 73 | plotly==2.7.0 74 | prometheus-client==0.2.0 75 | prompt-toolkit==1.0.15 76 | protobuf==3.6.1 77 | psycopg2==2.7.5 78 | ptyprocess==0.6.0 79 | py4j==0.10.4 80 | pyasn1==0.4.3 81 | pyasn1-modules==0.2.1 82 | pycparser==2.18 83 | pygal==2.4.0 84 | Pygments==2.2.0 85 | PyHamcrest==1.9.0 86 | pykerberos==1.1.14 87 | pymongo==3.8.0 88 | pyOpenSSL==18.0.0 89 | pyparsing==2.2.0 90 | pypng==0.0.18 91 | PySocks==1.6.8 92 | pyspark==2.2.1 93 | python-dateutil==2.7.3 94 | pytz==2018.5 95 | PyYAML==3.12 96 | pyzmq==17.1.0 97 | requests==2.19.1 98 | requests-kerberos==0.12.0 99 | s3transfer==0.1.13 100 | scikit-learn==0.20.0 101 | scipy==1.1.0 102 | Send2Trash==1.5.0 103 | service-identity==17.0.0 104 | simplegeneric==0.8.1 105 | six==1.11.0 106 | sklearn==0.0 107 | sparkmagic==0.12.5 108 | SQLAlchemy==1.2.10 109 | style==1.1.0 110 | tensorboard==1.12.0 111 | tensorflow==1.12.0 112 | tensorflow-plot==0.3.0 113 | termcolor==1.1.0 114 | terminado==0.8.1 115 | testpath==0.3.1 116 | torch==1.0.0 117 | torchvision==0.2.1 118 | tornado==5.1 119 | tqdm==4.28.1 120 | traitlets==4.3.2 121 | Twisted==17.5.0 122 | typing==3.6.4 123 | typing-extensions==3.6.5 124 | update==0.0.1 125 | urllib3==1.23 126 | wcwidth==0.1.7 127 | webencodings==0.5 128 | Werkzeug==0.14.1 129 | widgetsnbextension==3.3.0 130 | zope.interface==4.5.0 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #Copyright 2019 SAP SE 2 | #Licensed under the Apache License, Version 2.0 (the "License"); 3 | #you may not use this file except in compliance with the License. 
4 | #You may obtain a copy of the License at 5 | 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | #Unless required by applicable law or agreed to in writing, software 9 | #distributed under the License is distributed on an "AS IS" BASIS, 10 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | #See the License for the specific language governing permissions and 12 | #limitations under the License. 13 | from __future__ import print_function 14 | import time,datetime,argparse,os,random 15 | import shutil 16 | import torch.utils.data 17 | from cfg.load_config import opt, cfg_from_file 18 | import numpy as np 19 | 20 | ts = time.time() 21 | 22 | # Arguments 23 | parser=argparse.ArgumentParser(description='xxx') 24 | parser.add_argument('--dataset',default='mnist',type=str,required=False, choices=['mnist','svhn'],help='Dataset name') 25 | parser.add_argument('--method',default='DGMw',type=str,required=False, choices=['DGMa','DGMw'], help='Method to run.') 26 | #parser.add_argument('--cfg_file',default=None,type=str,required=False, help='Path to the configuration file') 27 | cfg=parser.parse_args() 28 | if cfg.method =="DGMw": 29 | if cfg.dataset == "mnist": 30 | cfg_file = 'cfg/cfg_mnist_dgmw.yml' 31 | cfg_from_file(cfg_file) 32 | elif cfg.dataset == "svhn": 33 | cfg_file = 'cfg/cfg_svhn_dgmw.yml' 34 | cfg_from_file(cfg_file) 35 | elif cfg.method =="DGMa": 36 | if cfg.dataset == "mnist": 37 | cfg_file = 'cfg/cfg_mnist_dgma.yml' 38 | cfg_from_file(cfg_file) 39 | elif cfg.dataset == "svhn": 40 | cfg_file = 'cfg/cfg_svhn_dgma.yml' 41 | cfg_from_file(cfg_file) 42 | print(opt) 43 | 44 | ####################################################################################################################### 45 | opt.device = torch.device("cuda:" + str(opt.device) if torch.cuda.is_available() else "cpu") 46 | if torch.cuda.is_available(): 47 | torch.cuda.set_device(opt.device) 48 | print(opt) 49 | 50 | 51 | try: 52 | os.makedirs(opt.outf) 53 | except OSError: 54 | pass 55 | try: 56 | os.makedirs(opt.outf_models) 57 | except OSError: 58 | pass 59 | try: 60 | os.makedirs(opt.outf + '/mask_histo') 61 | except: 62 | pass 63 | 64 | 65 | if opt.dataset=="mnist": 66 | from dataloaders import split_MNIST as dataloader 67 | elif opt.dataset=="svhn": 68 | from dataloaders import split_SVHN as dataloader 69 | if opt.method == "DGMw": 70 | from networks import net_DGMw as model 71 | from approaches import DGMw as approach 72 | elif opt.method == "DGMa": 73 | from networks import net_DGMa as model 74 | from approaches import DGMa as approach 75 | 76 | 77 | 78 | if opt.manualSeed is None: 79 | opt.manualSeed = random.randint(1, 10000) 80 | print("Random Seed: ", opt.manualSeed) 81 | random.seed(opt.manualSeed) 82 | torch.manual_seed(opt.manualSeed) 83 | np.random.seed(opt.manualSeed) 84 | 85 | if torch.cuda.is_available(): 86 | torch.cuda.manual_seed_all(opt.manualSeed) 87 | 88 | 89 | print('Load data...') 90 | data, taskcla, inputsize = dataloader.get(seed=opt.manualSeed,data_root=opt.dataroot+str(opt.imageSize), n_classes=1, imageSize=opt.imageSize) 91 | print('Input size =', inputsize, '\nTask info =', taskcla) 92 | for t in range(10): 93 | data[t]['train']['y'].data.fill_(t) 94 | data[t]['test']['y'].data.fill_(t) 95 | data[t]['valid']['y'].data.fill_(t) 96 | 97 | nz = int(opt.nz) 98 | ngf = int(opt.ngf) 99 | ndf = int(opt.ndf) 100 | nb_label = 10 101 | if opt.dataset == 'mnist': 102 | nc = 1 103 | elif opt.dataset == 'svhn': 104 | nc = 3 105 | 106 | #classes are added one by 
one, we innitialize G with one head output 107 | netG = model.netG(nz, ngf, nc, opt.smax_g, n_classes=1) 108 | print(netG) 109 | netD = model.netD(ndf, nc) 110 | print(netD) 111 | 112 | 113 | log_dir = opt.log_dir + datetime.datetime.fromtimestamp(ts).strftime('%Y_%m_%d_%H_%M_%S') 114 | if os.path.exists(log_dir): 115 | shutil.rmtree(log_dir) 116 | os.makedirs(log_dir) 117 | 118 | appr = approach.App(model, netG, netD, log_dir, opt.outf, niter=opt.niter, batchSize=opt.batchSize, 119 | imageSize=opt.imageSize, nz=int(opt.nz), nb_label=nb_label, cuda=torch.cuda.is_available(), beta1=opt.beta1, 120 | lr_D=opt.lr_D, lr_G=opt.lr_G, lamb_G=opt.lamb_G, 121 | reinit_D=opt.reinit_D, lambda_adv=opt.lambda_adv, lambda_wassersten=opt.lambda_wasserstein, dataset=opt.dataset, store_model = opt.store_models) 122 | 123 | 124 | for t in range(10): 125 | test_acc_task, conf_matrixes_task, mask_G = appr.train(data, t, smax_g=opt.smax_g,use_aux_G=opt.aux_G) 126 | -------------------------------------------------------------------------------- /run_DGMw_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 SAP SE 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | from __future__ import print_function 14 | from networks import net_DGMw_imnet as model 15 | from approaches import DGMw_imnet as approach 16 | import os 17 | import random 18 | import argparse 19 | import shutil 20 | import time 21 | import datetime 22 | import importlib 23 | import numpy as np 24 | from cfg.load_config import opt, cfg_from_file 25 | import torch.backends.cudnn as cudnn 26 | import torch.utils.data 27 | import torchvision.transforms as transforms 28 | from utils.folder import ImageFolder 29 | 30 | ts = time.time() 31 | 32 | # Arguments 33 | parser = argparse.ArgumentParser(description='xxx') 34 | parser.add_argument( 35 | '--dataset', 36 | default='imnet', 37 | type=str, 38 | required=False, 39 | choices=['imagenet'], 40 | help='Dataset name') 41 | parser.add_argument( 42 | '--cfg_file', 43 | default=None, 44 | type=str, 45 | required=False, 46 | help='Path to the configuration file') 47 | cfg = parser.parse_args() 48 | if cfg.cfg_file is not None: 49 | try: 50 | cfg_from_file(cfg.cfg_file) 51 | except FileNotFoundError: 52 | if cfg.dataset == "imnet": 53 | cfg_file = 'cfg/cfg_imnet_dgmw.yml' 54 | cfg_from_file(cfg_file) 55 | else: 56 | if cfg.dataset == "imnet": 57 | cfg_file = 'cfg/cfg_imnet_dgmw.yml' 58 | cfg_from_file(cfg_file) 59 | 60 | print(opt) 61 | try: 62 | os.makedirs(opt.outf) 63 | except OSError: 64 | pass 65 | 66 | try: 67 | os.makedirs(opt.outf_models) 68 | except OSError: 69 | pass 70 | 71 | 72 | if opt.manualSeed is None: 73 | opt.manualSeed = random.randint(1, 10000) 74 | print("Random Seed: ", opt.manualSeed) 75 | random.seed(opt.manualSeed) 76 | torch.manual_seed(opt.manualSeed) 77 | np.random.seed(opt.manualSeed) 78 | if opt.cuda: 79 | torch.cuda.manual_seed_all(opt.manualSeed) 80 | 81 | cudnn.benchmark = True 82 | if 
torch.cuda.is_available() and not opt.cuda: 83 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 84 | cuda1 = torch.device(opt.device_D) 85 | cuda2 = torch.device(opt.device_G) 86 | 87 | ngpu = int(1) 88 | nz = int(opt.nz) 89 | ngf = int(opt.ngf) 90 | ndf = int(opt.ndf) 91 | num_classes = int(100) 92 | nc = 3 93 | 94 | # resnet18 95 | netD = model.netD(2048, n_classes=10, device=cuda1) 96 | netG = model.netG(nz, ngf, nc, opt.smax_g, device=cuda2) 97 | print(netD) 98 | print(netG) 99 | ts = time.time() 100 | log_dir = opt.log_dir + \ 101 | datetime.datetime.fromtimestamp(ts).strftime('%Y_%m_%d_%H_%M_%S') 102 | importlib.reload(approach) 103 | if os.path.exists(log_dir): 104 | shutil.rmtree(log_dir) 105 | os.makedirs(log_dir) 106 | 107 | #idx = [1,15,29,45,59,65,81,89,90,99] 108 | idx = opt.class_idx_imnet 109 | appr = approach.App( 110 | model, 111 | netG, 112 | netD, 113 | log_dir, 114 | opt.outf, 115 | niter=opt.niter, 116 | batchSize=opt.batchSize, 117 | imageSize=opt.imageSize, 118 | nz=int( 119 | opt.nz), 120 | nb_label=num_classes, 121 | cuda=True, 122 | beta1=opt.beta1, 123 | lr_D=opt.lr_D, 124 | lr_G=opt.lr_G, 125 | lamb_G=opt.lamb_G, 126 | reinit_D=opt.reinit_D, 127 | lambd_adv=opt.lambda_adv, 128 | lambda_wassersten=opt.lambda_wasserstein, 129 | dataroot_test=opt.dataroot_val, 130 | dataroot=opt.dataroot, 131 | store_model=opt.store_models, 132 | out_models=opt.outf_models, 133 | calc_fid_imnet=opt.calc_fid_imnet, 134 | class_idx=idx) # , gpu_tracker=gpu_tracker) 135 | appr.writer.text_summary("opt", str(opt)) 136 | 137 | 138 | test_acc_tasks = [] 139 | conf_matrixes = [] 140 | for t in range(10): 141 | idx_ = [i + (t * 100) for i in idx] 142 | dataset = ImageFolder( 143 | root=opt.dataroot, 144 | transform=transforms.Compose([ 145 | transforms.Resize(opt.imageSize), 146 | transforms.CenterCrop(opt.imageSize), 147 | transforms.ToTensor(), 148 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 149 | ]), 150 | classes_idx=(idx_) 151 | ) 152 | dataloader = torch.utils.data.DataLoader( 153 | dataset, 154 | batch_size=opt.batchSize, 155 | shuffle=True, 156 | num_workers=int( 157 | opt.workers)) 158 | test_acc_task, conf_matrixes_task, mask_G = appr.train( 159 | dataloader, dataset, t, smax_g=opt.smax_g, use_aux_G=opt.aux_G) 160 | test_acc_tasks.append(test_acc_task) 161 | conf_matrixes.append(conf_matrixes_task) 162 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAP-archive/machine-learning-dgm/78786f0d9469cba201ad0108e4af2387574dc7c0/utils/__init__.py -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import scipy.linalg 8 | 9 | import chainer 10 | import chainer.cuda 11 | from chainer import Variable 12 | from chainer import serializers 13 | import chainer.functions as F 14 | 15 | sys.path.append(os.path.dirname(__file__)) 16 | from inception.inception_score import inception_score, Inception 17 | 18 | 19 | def sample_generate_light(gen, dst, rows=5, cols=5, seed=0): 20 | @chainer.training.make_extension() 21 | def make_image(trainer): 22 | np.random.seed(seed) 23 | n_images = rows * cols 24 | xp = gen.xp 25 | z = 
Variable(xp.asarray(gen.make_hidden(n_images))) 26 | with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 27 | x = gen(z) 28 | x = chainer.cuda.to_cpu(x.data) 29 | np.random.seed() 30 | 31 | x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8) 32 | _, _, H, W = x.shape 33 | x = x.reshape((rows, cols, 3, H, W)) 34 | x = x.transpose(0, 3, 1, 4, 2) 35 | x = x.reshape((rows * H, cols * W, 3)) 36 | 37 | preview_dir = '{}/preview'.format(dst) 38 | preview_path = preview_dir + '/image_latest.png' 39 | if not os.path.exists(preview_dir): 40 | os.makedirs(preview_dir) 41 | Image.fromarray(x).save(preview_path) 42 | 43 | return make_image 44 | 45 | 46 | def sample_generate(gen, dst, rows=10, cols=10, seed=0): 47 | """Visualization of rows*cols images randomly generated by the generator.""" 48 | @chainer.training.make_extension() 49 | def make_image(trainer): 50 | np.random.seed(seed) 51 | n_images = rows * cols 52 | xp = gen.xp 53 | z = Variable(xp.asarray(gen.make_hidden(n_images))) 54 | with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 55 | x = gen(z) 56 | x = chainer.cuda.to_cpu(x.data) 57 | np.random.seed() 58 | 59 | x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8) 60 | _, _, h, w = x.shape 61 | x = x.reshape((rows, cols, 3, h, w)) 62 | x = x.transpose(0, 3, 1, 4, 2) 63 | x = x.reshape((rows * h, cols * w, 3)) 64 | 65 | preview_dir = '{}/preview'.format(dst) 66 | preview_path = preview_dir + '/image{:0>8}.png'.format(trainer.updater.iteration) 67 | if not os.path.exists(preview_dir): 68 | os.makedirs(preview_dir) 69 | Image.fromarray(x).save(preview_path) 70 | 71 | return make_image 72 | 73 | 74 | def load_inception_model(): 75 | infile = "%s/inception/inception_score.model"%os.path.dirname(__file__) 76 | model = Inception() 77 | serializers.load_hdf5(infile, model) 78 | model.to_gpu() 79 | return model 80 | 81 | 82 | def calc_inception(gen, batchsize=100): 83 | @chainer.training.make_extension() 84 | def evaluation(trainer): 85 | model = load_inception_model() 86 | 87 | ims = [] 88 | xp = gen.xp 89 | 90 | n_ims = 50000 91 | for i in range(0, n_ims, batchsize): 92 | z = Variable(xp.asarray(gen.make_hidden(batchsize))) 93 | with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 94 | x = gen(z) 95 | x = chainer.cuda.to_cpu(x.data) 96 | x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8) 97 | ims.append(x) 98 | ims = np.asarray(ims) 99 | _, _, _, h, w = ims.shape 100 | ims = ims.reshape((n_ims, 3, h, w)).astype("f") 101 | 102 | mean, std = inception_score(model, ims) 103 | 104 | chainer.reporter.report({ 105 | 'inception_mean': mean, 106 | 'inception_std': std 107 | }) 108 | 109 | return evaluation 110 | 111 | 112 | def get_mean_cov(model, ims, batch_size=100): 113 | n, c, w, h = ims.shape 114 | n_batches = int(math.ceil(float(n) / float(batch_size))) 115 | 116 | xp = model.xp 117 | 118 | print('Batch size:', batch_size) 119 | print('Total number of images:', n) 120 | print('Total number of batches:', n_batches) 121 | 122 | ys = xp.empty((n, 2048), dtype=xp.float32) 123 | 124 | for i in range(n_batches): 125 | print('Running batch', i + 1, '/', n_batches, '...') 126 | batch_start = (i * batch_size) 127 | batch_end = min((i + 1) * batch_size, n) 128 | 129 | ims_batch = ims[batch_start:batch_end] 130 | ims_batch = xp.asarray(ims_batch) # To GPU if using CuPy 131 | ims_batch = Variable(ims_batch) 132 | 133 | # Resize image to the 
shape expected by the inception module 134 | if (w, h) != (299, 299): 135 | ims_batch = F.resize_images(ims_batch, (299, 299)) # bilinear 136 | 137 | # Feed images to the inception module to get the features 138 | with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 139 | y = model(ims_batch, get_feature=True) 140 | ys[batch_start:batch_end] = y.data 141 | 142 | mean = chainer.cuda.to_cpu(xp.mean(ys, axis=0)) 143 | # cov = F.cross_covariance(ys, ys, reduce="no").data.get() 144 | cov = np.cov(chainer.cuda.to_cpu(ys).T) 145 | 146 | return mean, cov 147 | 148 | def FID(m0,c0,m1,c1): 149 | ret = 0 150 | ret += np.sum((m0-m1)**2) 151 | ret += np.trace(c0 + c1 - 2.0*scipy.linalg.sqrtm(np.dot(c0, c1))) 152 | return np.real(ret) 153 | 154 | def calc_FID(gen, batchsize=100, stat_file="%s/cifar-10-fid.npz"%os.path.dirname(__file__)): 155 | """Frechet Inception Distance proposed by https://arxiv.org/abs/1706.08500""" 156 | @chainer.training.make_extension() 157 | def evaluation(trainer): 158 | model = load_inception_model() 159 | stat = np.load(stat_file) 160 | 161 | n_ims = 10000 162 | xp = gen.xp 163 | xs = [] 164 | for i in range(0, n_ims, batchsize): 165 | z = Variable(xp.asarray(gen.make_hidden(batchsize))) 166 | with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 167 | x = gen(z) 168 | x = chainer.cuda.to_cpu(x.data) 169 | x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype="f") 170 | xs.append(x) 171 | xs = np.asarray(xs) 172 | _, _, _, h, w = xs.shape 173 | 174 | with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 175 | mean, cov = get_mean_cov(model, np.asarray(xs).reshape((-1, 3, h, w))) 176 | fid = FID(stat["mean"], stat["cov"], mean, cov) 177 | 178 | chainer.reporter.report({ 179 | 'FID': fid, 180 | }) 181 | 182 | return evaluation -------------------------------------------------------------------------------- /utils/folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 8 | 9 | 10 | def is_image_file(filename): 11 | """Checks if a file is an image. 
12 | 13 | Args: 14 | filename (string): path to a file 15 | 16 | Returns: 17 | bool: True if the filename ends with a known image extension 18 | """ 19 | filename_lower = filename.lower() 20 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 21 | 22 | 23 | def find_classes(dir, classes_idx=None): 24 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 25 | classes.sort() 26 | #print(classes) 27 | if classes_idx is not None: 28 | #assert type(classes_idx) == tuple 29 | if type(classes_idx) == tuple: 30 | start, end = classes_idx 31 | classes = classes[start:end] 32 | else: 33 | classes = [classes[i] for i in classes_idx] 34 | print(classes) 35 | class_to_idx = {classes[i]: i for i in range(len(classes))} 36 | return classes, class_to_idx 37 | 38 | 39 | def make_dataset(dir, class_to_idx): 40 | images = [] 41 | dir = os.path.expanduser(dir) 42 | for target in sorted(os.listdir(dir)): 43 | if target not in class_to_idx: 44 | continue 45 | d = os.path.join(dir, target) 46 | if not os.path.isdir(d): 47 | continue 48 | 49 | for root, _, fnames in sorted(os.walk(d)): 50 | for fname in sorted(fnames): 51 | if is_image_file(fname): 52 | path = os.path.join(root, fname) 53 | item = (path, class_to_idx[target]) 54 | images.append(item) 55 | 56 | return images 57 | 58 | 59 | def pil_loader(path): 60 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 61 | with open(path, 'rb') as f: 62 | with Image.open(f) as img: 63 | return img.convert('RGB') 64 | 65 | 66 | def accimage_loader(path): 67 | import accimage 68 | try: 69 | return accimage.Image(path) 70 | except IOError: 71 | # Potentially a decoding problem, fall back to PIL.Image 72 | return pil_loader(path) 73 | 74 | 75 | def default_loader(path): 76 | from torchvision import get_image_backend 77 | if get_image_backend() == 'accimage': 78 | return accimage_loader(path) 79 | else: 80 | return pil_loader(path) 81 | 82 | 83 | class ImageFolder(data.Dataset): 84 | """A generic data loader where the images are arranged in this way: :: 85 | 86 | root/dog/xxx.png 87 | root/dog/xxy.png 88 | root/dog/xxz.png 89 | 90 | root/cat/123.png 91 | root/cat/nsdf3.png 92 | root/cat/asd932_.png 93 | 94 | Args: 95 | root (string): Root directory path. 96 | transform (callable, optional): A function/transform that takes in an PIL image 97 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 98 | target_transform (callable, optional): A function/transform that takes in the 99 | target and transforms it. 100 | loader (callable, optional): A function to load an image given its path. 101 | 102 | Attributes: 103 | classes (list): List of the class names. 104 | class_to_idx (dict): Dict with items (class_name, class_index). 
105 | imgs (list): List of (image path, class_index) tuples 106 | """ 107 | 108 | def __init__(self, root, transform=None, target_transform=None, 109 | loader=default_loader, classes_idx=None): 110 | self.classes_idx = classes_idx 111 | classes, class_to_idx = find_classes(root, self.classes_idx) 112 | imgs = make_dataset(root, class_to_idx) 113 | if len(imgs) == 0: 114 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 115 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 116 | 117 | self.root = root 118 | self.imgs = imgs 119 | self.classes = classes 120 | self.class_to_idx = class_to_idx 121 | self.transform = transform 122 | self.target_transform = target_transform 123 | self.loader = loader 124 | 125 | def __getitem__(self, index): 126 | """ 127 | Args: 128 | index (int): Index 129 | 130 | Returns: 131 | tuple: (image, target) where target is class_index of the target class. 132 | """ 133 | path, target = self.imgs[index] 134 | img = self.loader(path) 135 | if self.transform is not None: 136 | img = self.transform(img) 137 | if self.target_transform is not None: 138 | target = self.target_transform(target) 139 | 140 | return img, target 141 | 142 | def __len__(self): 143 | return len(self.imgs) 144 | -------------------------------------------------------------------------------- /utils/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, normalizes the input to the statistics the pretrained 43 | Inception network expects 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. 
46 |             for finetuning the network
47 |         """
48 |         super(InceptionV3, self).__init__()
49 | 
50 |         self.resize_input = resize_input
51 |         self.normalize_input = normalize_input
52 |         self.output_blocks = sorted(output_blocks)
53 |         self.last_needed_block = max(output_blocks)
54 | 
55 |         assert self.last_needed_block <= 3, \
56 |             'Last possible output block index is 3'
57 | 
58 |         self.blocks = nn.ModuleList()
59 | 
60 |         inception = models.inception_v3(pretrained=True)
61 | 
62 |         # Block 0: input to maxpool1
63 |         block0 = [
64 |             inception.Conv2d_1a_3x3,
65 |             inception.Conv2d_2a_3x3,
66 |             inception.Conv2d_2b_3x3,
67 |             nn.MaxPool2d(kernel_size=3, stride=2)
68 |         ]
69 |         self.blocks.append(nn.Sequential(*block0))
70 | 
71 |         # Block 1: maxpool1 to maxpool2
72 |         if self.last_needed_block >= 1:
73 |             block1 = [
74 |                 inception.Conv2d_3b_1x1,
75 |                 inception.Conv2d_4a_3x3,
76 |                 nn.MaxPool2d(kernel_size=3, stride=2)
77 |             ]
78 |             self.blocks.append(nn.Sequential(*block1))
79 | 
80 |         # Block 2: maxpool2 to aux classifier
81 |         if self.last_needed_block >= 2:
82 |             block2 = [
83 |                 inception.Mixed_5b,
84 |                 inception.Mixed_5c,
85 |                 inception.Mixed_5d,
86 |                 inception.Mixed_6a,
87 |                 inception.Mixed_6b,
88 |                 inception.Mixed_6c,
89 |                 inception.Mixed_6d,
90 |                 inception.Mixed_6e,
91 |             ]
92 |             self.blocks.append(nn.Sequential(*block2))
93 | 
94 |         # Block 3: aux classifier to final avgpool
95 |         if self.last_needed_block >= 3:
96 |             block3 = [
97 |                 inception.Mixed_7a,
98 |                 inception.Mixed_7b,
99 |                 inception.Mixed_7c,
100 |                 nn.AdaptiveAvgPool2d(output_size=(1, 1))
101 |             ]
102 |             self.blocks.append(nn.Sequential(*block3))
103 | 
104 |         for param in self.parameters():
105 |             param.requires_grad = requires_grad
106 | 
107 |     def forward(self, inp):
108 |         """Get Inception feature maps
109 | 
110 |         Parameters
111 |         ----------
112 |         inp : torch.autograd.Variable
113 |             Input tensor of shape Bx3xHxW. Values are expected to be in
114 |             range (0, 1)
115 | 
116 |         Returns
117 |         -------
118 |         List of torch.autograd.Variable, corresponding to the selected output
119 |         blocks, sorted ascending by index
120 |         """
121 |         outp = []
122 |         x = inp
123 | 
124 |         if self.resize_input:
125 |             x = F.upsample(x, size=(299, 299), mode='bilinear')
126 | 
127 |         if self.normalize_input:
128 |             # Per-channel affine rescaling to the statistics the pretrained
129 |             # network expects (same transform torchvision applies internally).
130 |             x = x.clone()
131 |             x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
132 |             x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
133 |             x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
134 | 
135 |         for idx, block in enumerate(self.blocks):
136 |             x = block(x)
137 |             if idx in self.output_blocks:
138 |                 outp.append(x)
139 | 
140 |             if idx == self.last_needed_block:
141 |                 break
142 | 
143 |         return outp
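144 | 
145 | 
146 | # Usage sketch: extracting final-average-pooling features, e.g. for
147 | # feature-based evaluation metrics. The random batch below is an assumption
148 | # for illustration; real inputs should be scaled to the (0, 1) range.
149 | #
150 | # import torch
151 | # model = InceptionV3(output_blocks=[3])
152 | # model.eval()
153 | # feats = model(torch.rand(4, 3, 64, 64))[0]  # shape: (4, 2048, 1, 1)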
--------------------------------------------------------------------------------
/utils/inception_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 | import torch.utils.data
6 | 
7 | from torchvision.models.inception import inception_v3
8 | 
9 | import numpy as np
10 | from scipy.stats import entropy
11 | 
12 | def inception_score(imgs, cuda=True, batch_size=16, resize=False, splits=1):
13 |     """Computes the inception score of the generated images imgs
14 | 
15 |     imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
16 |     cuda -- whether or not to run on GPU
17 |     batch_size -- batch size for feeding into Inception v3
18 |     resize -- whether to bilinearly resize images to 299x299 before scoring
19 |     splits -- number of splits
20 |     """
21 |     N = len(imgs)
22 | 
23 |     assert batch_size > 0
24 |     assert N > batch_size
25 | 
26 |     # Set up dtype
27 |     if cuda:
28 |         dtype = torch.cuda.FloatTensor
29 |     else:
30 |         if torch.cuda.is_available():
31 |             print("WARNING: You have a CUDA device, so you should probably set cuda=True")
32 |         dtype = torch.FloatTensor
33 | 
34 |     # Set up dataloader
35 |     dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
36 | 
37 |     # Load inception model
38 |     inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
39 |     inception_model.eval()
40 |     up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
41 |     def get_pred(x):
42 |         if resize:
43 |             x = up(x)
44 |         x = inception_model(x)
45 |         return F.softmax(x, dim=1).data.cpu().numpy()
46 | 
47 |     # Get predictions
48 |     preds = np.zeros((N, 1000))
49 | 
50 |     for i, batch in enumerate(dataloader, 0):
51 |         batch = batch.type(dtype)
52 |         batchv = Variable(batch)
53 |         batch_size_i = batch.size()[0]
54 | 
55 |         preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
56 | 
57 |     # Now compute the mean KL divergence
58 |     split_scores = []
59 | 
60 |     for k in range(splits):
61 |         part = preds[k * (N // splits): (k+1) * (N // splits), :]
62 |         py = np.mean(part, axis=0)
63 |         scores = []
64 |         for i in range(part.shape[0]):
65 |             pyx = part[i, :]
66 |             scores.append(entropy(pyx, py))
67 |         split_scores.append(np.exp(np.mean(scores)))
68 | 
69 |     return np.mean(split_scores), np.std(split_scores)
70 | 
71 | if __name__ == '__main__':
72 |     class IgnoreLabelDataset(torch.utils.data.Dataset):
73 |         def __init__(self, orig):
74 |             self.orig = orig
75 | 
76 |         def __getitem__(self, index):
77 |             return self.orig[index][0]
78 | 
79 |         def __len__(self):
80 |             return len(self.orig)
81 | 
82 |     import torchvision.datasets as dset
83 |     import torchvision.transforms as transforms
84 | 
85 |     cifar = dset.CIFAR10(root='data/', download=True,
86 |                          transform=transforms.Compose([
87 |                              transforms.Resize(32),
88 |                              transforms.ToTensor(),
89 |                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
90 |                          ])
91 |     )
92 | 
93 |     print("Calculating Inception Score...")
94 |     print(inception_score(IgnoreLabelDataset(cifar), cuda=True, batch_size=32, resize=True, splits=10))
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
2 | import tensorflow as tf
3 | import numpy as np
4 | import scipy.misc
5 | try:
6 |     from StringIO import StringIO  # Python 2.7
7 | except ImportError:
8 |     from io import BytesIO  # Python 3.x
9 | 
10 | 
11 | class Logger(object):
12 | 
13 |     def __init__(self, log_dir):
14 |         """Create a summary writer logging to log_dir."""
15 |         self.writer = tf.summary.FileWriter(log_dir)
16 | 
17 |     def scalar_summary(self, tag, value, step):
18 |         """Log a scalar variable."""
19 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
20 |         self.writer.add_summary(summary, step)
21 | 
22 |     def text_summary(self, tag, value):
23 |         """Log a text variable."""
24 |         text_tensor = tf.make_tensor_proto(value, dtype=tf.string)
25 |         meta = tf.SummaryMetadata()
26 |         meta.plugin_data.plugin_name = "text"
27 | 
28 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, metadata=meta, tensor=text_tensor)])
29 |         self.writer.add_summary(summary)
30 | 
31 |     def image_summary(self, tag, images, step):
32 |         """Log a list of images."""
33 | 
34 |         img_summaries = []
35 |         for i, img in enumerate(images):
36 |             # Write the image to a buffer (StringIO on Python 2, BytesIO on Python 3)
37 |             try:
38 |                 s = StringIO()
39 |             except NameError:
40 |                 s = BytesIO()
41 |             # NOTE: scipy.misc.toimage was removed in newer SciPy releases
42 |             scipy.misc.toimage(img).save(s, format="png")
43 | 
44 |             # Create an Image object
45 |             img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
46 |                                        height=img.shape[0],
47 |                                        width=img.shape[1])
48 |             # Create a Summary value
49 |             img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
50 | 
51 |         # Create and write Summary
52 |         summary = tf.Summary(value=img_summaries)
53 |         self.writer.add_summary(summary, step)
54 | 
55 |     def histo_summary(self, tag, values, step, bins=1000):
56 |         """Log a histogram of the tensor of values."""
57 | 
58 |         # Create a histogram using numpy
59 |         counts, bin_edges = np.histogram(values, bins=bins)
60 | 
61 |         # Fill the fields of the histogram proto
62 |         hist = tf.HistogramProto()
63 |         hist.min = float(np.min(values))
64 |         hist.max = float(np.max(values))
65 |         hist.num = int(np.prod(values.shape))
66 |         hist.sum = float(np.sum(values))
67 |         hist.sum_squares = float(np.sum(values**2))
68 | 
69 |         # Drop the start of the first bin
70 |         bin_edges = bin_edges[1:]
71 | 
72 |         # Add bin edges and counts
73 |         for edge in bin_edges:
74 |             hist.bucket_limit.append(edge)
75 |         for c in counts:
76 |             hist.bucket.append(c)
77 | 
78 |         # Create and write Summary
79 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
80 |         self.writer.add_summary(summary, step)
81 |         self.writer.flush()
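82 | 
83 | 
84 | # Usage sketch (TensorFlow 1.x API): the log directory below is a
85 | # hypothetical example.
86 | #
87 | # logger = Logger('./logs/demo')
88 | # for step in range(10):
89 | #     logger.scalar_summary('train/loss', 1.0 / (step + 1), step)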
--------------------------------------------------------------------------------
/utils/spectral_normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 | 
5 | 
6 | def l2normalize(v, eps=1e-12):
7 |     return v / (v.norm() + eps)
8 | 
9 | 
10 | class SpectralNorm(nn.Module):
11 |     """Wraps a module and rescales its weight by its spectral norm,
12 |     estimated with power iteration (Miyato et al., 2018)."""
13 | 
14 |     def __init__(self, module, name='weight', power_iterations=1):
15 |         super(SpectralNorm, self).__init__()
16 |         self.module = module
17 |         self.name = name
18 |         self.power_iterations = power_iterations
19 |         if not self._made_params():
20 |             self._make_params()
21 | 
22 |     def _update_u_v(self):
23 |         u = getattr(self.module, self.name + "_u")
24 |         v = getattr(self.module, self.name + "_v")
25 |         w = getattr(self.module, self.name + "_bar")
26 | 
27 |         height = w.data.shape[0]
28 |         for _ in range(self.power_iterations):
29 |             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
30 |             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
31 | 
32 |         sigma = u.dot(w.view(height, -1).mv(v))
33 |         setattr(self.module, self.name, w / sigma.expand_as(w))
34 | 
35 |     def _made_params(self):
36 |         try:
37 |             getattr(self.module, self.name + "_u")
38 |             getattr(self.module, self.name + "_v")
39 |             getattr(self.module, self.name + "_bar")
40 |             return True
41 |         except AttributeError:
42 |             return False
43 | 
44 |     def _make_params(self):
45 |         w = getattr(self.module, self.name)
46 | 
47 |         height = w.data.shape[0]
48 |         width = w.view(height, -1).data.shape[1]
49 | 
50 |         u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
51 |         v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
52 |         u.data = l2normalize(u.data)
53 |         v.data = l2normalize(v.data)
54 |         w_bar = Parameter(w.data)
55 | 
56 |         del self.module._parameters[self.name]
57 | 
58 |         self.module.register_parameter(self.name + "_u", u)
59 |         self.module.register_parameter(self.name + "_v", v)
60 |         self.module.register_parameter(self.name + "_bar", w_bar)
61 | 
62 |     def forward(self, *args):
63 |         self._update_u_v()
64 |         return self.module.forward(*args)
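65 | 
66 | 
67 | # Usage sketch: wrapping a layer so its weight is divided by its estimated
68 | # spectral norm on every forward pass. The shapes below are illustrative.
69 | #
70 | # conv = SpectralNorm(nn.Conv2d(3, 64, kernel_size=3, padding=1))
71 | # out = conv(torch.randn(1, 3, 32, 32))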
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | from copy import deepcopy
4 | import torch
5 | from tqdm import tqdm
6 | import torch.nn as nn
7 | import torch.nn.init as nninit
8 | 
9 | class AverageMeter(object):
10 |     """Computes and stores the average and current value"""
11 |     def __init__(self):
12 |         self.reset()
13 | 
14 |     def reset(self):
15 |         self.val = 0
16 |         self.avg = 0
17 |         self.sum = 0
18 |         self.count = 0
19 | 
20 |     def update(self, val, n=1):
21 |         self.val = val
22 |         self.sum += val * n
23 |         self.count += n
24 |         self.avg = self.sum / self.count
25 | 
26 | class defaultlist(list):
27 |     """A list that grows on demand, filling missing slots with fx()."""
28 |     def __init__(self, fx):
29 |         self._fx = fx
30 |     def _fill(self, index):
31 |         while len(self) <= index:
32 |             self.append(self._fx())
33 |     def __setitem__(self, index, value):
34 |         self._fill(index)
35 |         list.__setitem__(self, index, value)
36 |     def __getitem__(self, index):
37 |         self._fill(index)
38 |         return list.__getitem__(self, index)
39 | 
40 | 
41 | def adjust_learning_rate(optimizer, epoch, init_lr):
42 |     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
43 |     # init_lr is passed explicitly; the original referenced an undefined global `args`
44 |     lr = init_lr * (0.1 ** (epoch // 30))
45 |     for param_group in optimizer.param_groups:
46 |         param_group['lr'] = lr
47 | 
48 | 
49 | def accuracy(output, target, topk=(1,)):
50 |     """Computes the accuracy over the k top predictions for the specified values of k"""
51 |     with torch.no_grad():
52 |         maxk = max(topk)
53 |         batch_size = target.size(0)
54 | 
55 |         _, pred = output.topk(maxk, 1, True, True)
56 |         pred = pred.t()
57 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
58 | 
59 |         res = []
60 |         for k in topk:
61 |             correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
62 |             res.append(correct_k.mul_(100.0 / batch_size))
63 |         return res
64 | 
65 | # custom weights initialization called on netG and netD
66 | def weights_init_g(m):
67 |     classname = m.__class__.__name__
68 |     relu_gain = nninit.calculate_gain('relu')
69 |     # ReLU gain for non-linear layers, unit gain for linear layers
70 |     if classname.find('Linear') == -1:
71 |         gain = relu_gain
72 |     else:
73 |         gain = 1.0
74 |     if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
75 |         print(classname)
76 |         nninit.xavier_uniform_(m.weight.data, gain=gain)
77 |         if m.bias is not None:
78 |             m.bias.data.zero_()
79 | 
80 | 
81 | def weights_init(m):
82 |     classname = m.__class__.__name__
83 |     if classname.find('Conv') != -1:
84 |         m.weight.data.normal_(0.0, 0.02)
85 |     elif classname.find('BatchNorm') != -1:
86 |         m.weight.data.normal_(1.0, 0.02)
87 |         m.bias.data.fill_(0)
88 | 
89 | # compute the current classification accuracy
90 | def compute_acc(preds, labels):
91 |     preds_ = preds.data.max(1)[1]
92 |     correct = preds_.eq(labels.data).cpu().sum()
93 |     acc = float(correct) / float(len(labels.data)) * 100.0
94 |     return acc
95 | ########################################################################################################################
96 | 
97 | def print_model_report(model):
98 |     print('-'*100)
99 |     print(model)
100 |     print('Dimensions =', end=' ')
101 |     count = 0
102 |     for p in model.parameters():
103 |         print(p.size(), end=' ')
104 |         count += np.prod(p.size())
105 |     print()
106 |     print('Num parameters = %s' % (human_format(count)))
107 |     print('-'*100)
108 |     return count
109 | 
110 | def human_format(num):
111 |     magnitude = 0
112 |     while abs(num) >= 1000:
113 |         magnitude += 1
114 |         num /= 1000.0
115 |     return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude])
116 | 
117 | def print_optimizer_config(optim):
118 |     if optim is None:
119 |         print(optim)
120 |     else:
121 |         print(optim, '=', end=' ')
122 |         opt = optim.param_groups[0]
123 |         for n in opt.keys():
124 |             if not n.startswith('param'):
125 |                 print(n + ':', opt[n], end=', ')
126 |         print()
127 |     return
128 | 
129 | ########################################################################################################################
130 | 
131 | def get_model(model):
132 |     return deepcopy(model.state_dict())
133 | 
134 | def set_model_(model, state_dict):
135 |     model.load_state_dict(deepcopy(state_dict))
136 |     return
137 | 
138 | def freeze_model(model):
139 |     for param in model.parameters():
140 |         param.requires_grad = False
141 |     return
142 | 
143 | ########################################################################################################################
144 | 
145 | def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
146 |     return int(np.floor((Lin + 2*padding - dilation*(kernel_size-1) - 1) / float(stride) + 1))
147 | 
148 | ########################################################################################################################
149 | 
150 | def compute_mean_std_dataset(dataset):
151 |     # dataset is expected to already apply ToTensor
152 |     mean = 0
153 |     std = 0
154 |     loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
155 |     for image, _ in loader:
156 |         mean += image.mean(3).mean(2)
157 |     mean /= len(dataset)
158 | 
159 |     mean_expanded = mean.view(mean.size(0), mean.size(1), 1, 1).expand_as(image)
160 |     for image, _ in loader:
161 |         std += (image - mean_expanded).pow(2).sum(3).sum(2)
162 | 
163 |     std = (std / (len(dataset)*image.size(2)*image.size(3) - 1)).sqrt()
164 | 
165 |     return mean, std
166 | 
167 | ########################################################################################################################
168 | 
169 | def fisher_matrix_diag(t, x, y, model, criterion, sbatch=20):
170 |     # Init
171 |     fisher = {}
172 |     for n, p in model.named_parameters():
173 |         fisher[n] = 0*p.data
174 |     # Compute
175 |     model.train()
176 |     for i in tqdm(range(0, x.size(0), sbatch), desc='Fisher diagonal', ncols=100, ascii=True):
177 |         b = torch.LongTensor(np.arange(i, np.min([i+sbatch, x.size(0)]))).cuda()
178 |         images = torch.autograd.Variable(x[b], volatile=False)
179 |         target = torch.autograd.Variable(y[b], volatile=False)
180 |         # Forward and backward
181 |         model.zero_grad()
182 |         outputs = model.forward(images)
183 |         loss = criterion(t, outputs[t], target)
184 |         loss.backward()
185 |         # Accumulate squared gradients
186 |         for n, p in model.named_parameters():
187 |             if p.grad is not None:
188 |                 fisher[n] += sbatch*p.grad.data.pow(2)
189 |     # Mean
190 |     for n, _ in model.named_parameters():
191 |         fisher[n] = fisher[n]/x.size(0)
192 |         fisher[n] = torch.autograd.Variable(fisher[n], requires_grad=False)
193 |     return fisher
194 | 
195 | ########################################################################################################################
196 | 
197 | def cross_entropy(outputs, targets, exp=1, size_average=True, eps=1e-5):
198 |     # Soft cross-entropy between two sets of logits
199 |     out = torch.nn.functional.softmax(outputs, dim=1)
200 |     tar = torch.nn.functional.softmax(targets, dim=1)
201 |     if exp != 1:
202 |         out = out.pow(exp)
203 |         out = out/out.sum(1).view(-1, 1).expand_as(out)
204 |         tar = tar.pow(exp)
205 |         tar = tar/tar.sum(1).view(-1, 1).expand_as(tar)
206 |     out = out + eps/out.size(1)
207 |     out = out/out.sum(1).view(-1, 1).expand_as(out)
208 |     ce = -(tar*out.log()).sum(1)
209 |     if size_average:
210 |         ce = ce.mean()
211 |     return ce
212 | 
213 | ########################################################################################################################
214 | 
215 | def set_req_grad(layer, req_grad):
216 |     if hasattr(layer, 'weight'):
217 |         layer.weight.requires_grad = req_grad
218 |     if hasattr(layer, 'bias'):
219 |         layer.bias.requires_grad = req_grad
220 |     return
221 | 
222 | ########################################################################################################################
223 | 
224 | def is_number(s):
225 |     try:
226 |         float(s)
227 |         return True
228 |     except ValueError:
229 |         pass
230 | 
231 |     try:
232 |         import unicodedata
233 |         unicodedata.numeric(s)
234 |         return True
235 |     except (TypeError, ValueError):
236 |         pass
237 | 
238 |     return False
239 | ########################################################################################################################
240 | 
241 | 
242 | def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
243 |     # First figure out what the size of the output should be
244 |     N, C, H, W = x_shape
245 |     assert (H + 2 * padding - field_height) % stride == 0
246 |     assert (W + 2 * padding - field_width) % stride == 0
247 |     out_height = (H + 2 * padding - field_height) // stride + 1
248 |     out_width = (W + 2 * padding - field_width) // stride + 1
249 | 
250 |     i0 = np.repeat(np.arange(field_height), field_width)
251 |     i0 = np.tile(i0, C)
252 |     i1 = stride * np.repeat(np.arange(out_height), out_width)
253 |     j0 = np.tile(np.arange(field_width), field_height * C)
254 |     j1 = stride * np.tile(np.arange(out_width), out_height)
255 |     i = i0.reshape(-1, 1) + i1.reshape(1, -1)
256 |     j = j0.reshape(-1, 1) + j1.reshape(1, -1)
257 | 
258 |     k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
259 | 
260 |     return (k, i, j)
261 | 
262 | 
263 | def im2col_indices(x, field_height, field_width, padding=1, stride=1):
264 |     """ An implementation of im2col based on some fancy indexing """
265 |     # Zero-pad the input
266 |     p = padding
267 |     pad = torch.nn.ConstantPad3d((p, p, p, p), 0)
268 |     x_padded = pad(x)
269 | 
270 |     k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
271 |                                  stride)
272 | 
273 |     cols = x_padded[:, k, i, j]
274 |     C = x.shape[1]
275 |     cols = cols.permute(1, 2, 0).contiguous()
276 |     cols = cols.view(field_height * field_width * C, -1)
277 |     return cols
278 | 
279 | 
280 | def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1,
281 |                    stride=1):
282 |     """ An implementation of col2im based on fancy indexing and np.add.at """
283 |     N, C, H, W = x_shape
284 |     H_padded, W_padded = H + 2 * padding, W + 2 * padding
285 |     x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
286 |     k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding,
287 |                                  stride)
288 |     cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
289 |     cols_reshaped = cols_reshaped.transpose(2, 0, 1)
290 |     np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
291 |     if padding == 0:
292 |         return x_padded
293 |     return x_padded[:, :, padding:-padding, padding:-padding]
294 | 
295 | 
296 | class ZeroOneNorm(object):
297 |     """Scales each input tensor by its maximum absolute value."""
298 | 
299 |     def __call__(self, *inputs):
300 |         outputs = []
301 |         for _input in inputs:
302 |             _max_val = torch.abs(_input).max()
303 |             _input = _input / _max_val
304 |             outputs.append(_input)
305 |         return outputs if len(outputs) > 1 else outputs[0]
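306 | 
307 | 
308 | # Usage sketch: unfolding a batch of images into patch columns, and scaling a
309 | # tensor with ZeroOneNorm. The random data and shapes are illustrative.
310 | #
311 | # x = torch.randn(2, 3, 5, 5)
312 | # cols = im2col_indices(x, 3, 3, padding=1, stride=1)  # shape: (27, 50)
313 | # x_scaled = ZeroOneNorm()(x)  # scaled into [-1, 1] by the max absolute value
--------------------------------------------------------------------------------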