├── .gitignore ├── Classification ├── README.md ├── main.py ├── model.py ├── runs │ ├── ICP_cifar100_densenet40 │ │ └── events.out.tfevents.1555916855.MAC-U3S2 │ ├── ICP_cifar100_googlenet │ │ └── events.out.tfevents.1556190825.MAC-U3S2 │ ├── ICP_cifar100_resnet20 │ │ └── events.out.tfevents.1556005874.MAC-U3S2 │ ├── ICP_cifar100_vgg16 │ │ └── events.out.tfevents.1555905801.MAC-U3S2 │ ├── ICP_cifar10_densenet40 │ │ └── events.out.tfevents.1555551815.MAC-U3S2 │ ├── ICP_cifar10_googlenet │ │ └── events.out.tfevents.1556112349.MAC-U3S2 │ ├── ICP_cifar10_resnet20 │ │ ├── events.out.tfevents.1555501545.MAC-U3S2 │ │ ├── events.out.tfevents.1569207192.MAC-U3S2 │ │ └── events.out.tfevents.1569214119.MAC-U3S2 │ └── ICP_cifar10_vgg16 │ │ └── events.out.tfevents.1555379025.MAC-U3S2 ├── scripts │ ├── Test_cifar100_densenet40.sh │ ├── Test_cifar100_googlenet.sh │ ├── Test_cifar100_resnet20.sh │ ├── Test_cifar100_vgg16.sh │ ├── Test_cifar10_densenet40.sh │ ├── Test_cifar10_googlenet.sh │ ├── Test_cifar10_resnet20.sh │ ├── Test_cifar10_vgg16.sh │ ├── Train_cifar100_densenet40.sh │ ├── Train_cifar100_googlenet.sh │ ├── Train_cifar100_resnet20.sh │ ├── Train_cifar100_vgg16.sh │ ├── Train_cifar10_densenet40.sh │ ├── Train_cifar10_googlenet.sh │ ├── Train_cifar10_resnet20.sh │ └── Train_cifar10_vgg16.sh └── utils.py ├── Disentanglement ├── MIG_Score │ ├── disentanglement_metrics.py │ ├── metric_helpers │ │ ├── loader.py │ │ └── mi_metric.py │ └── model.py ├── README.md ├── dataset.py ├── main.py ├── model.py ├── scripts │ ├── MIG_dsprites.sh │ ├── MIG_faces.sh │ ├── Test_celeba.sh │ ├── Test_dsprites.sh │ ├── Test_faces.sh │ ├── Train_celeba.sh │ ├── Train_dsprites.sh │ ├── Train_faces.sh │ └── prepare_data.sh ├── solver.py └── utils.py ├── Paper ├── poster.pdf └── 信息竞争式的多样化特征学习.pdf └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 
| # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Classification/README.md: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | - python3 3 | - pytorch 1.0 4 | - tensorboardX 5 | 6 | # Prepare Datasets 7 | The cifar10 and cifar100 datasets are downloaded automatically when the code is run.

# Training
To train a model such as vgg16 on cifar10 with ICP:
> chmod +x ./scripts/Train_cifar10_vgg16.sh

> sh ./scripts/Train_cifar10_vgg16.sh

# Testing
To test a trained model such as vgg16 on cifar10:
> chmod +x ./scripts/Test_cifar10_vgg16.sh

> sh ./scripts/Test_cifar10_vgg16.sh

# Results and Logs
The error rates (%) of ICP on cifar-10:

|          | VGG16    | GoogLeNet | ResNet20 | DenseNet40 |
| :---:    | :---:    | :---:     | :---:    | :---:      |
| Baseline | 6.67     | 4.92      | 7.63     | 5.83       |
| ICP-ALL  | 6.97     | 4.76      | 6.47     | 6.13       |
| ICP-COM  | 6.59     | 4.67      | 7.33     | 5.63       |
| **ICP**  | **6.10** | **4.26**  | **6.01** | **4.99**   |

The error rates (%) of ICP on cifar-100:

|          | VGG16     | GoogLeNet | ResNet20  | DenseNet40 |
| :---:    | :---:     | :---:     | :---:     | :---:      |
| Baseline | 26.41     | 20.68     | 31.91     | 27.55      |
| ICP-ALL  | 26.73     | 20.90     | 28.35     | 27.51      |
| ICP-COM  | 26.37     | 20.81     | 32.76     | 26.85      |
| **ICP**  | **24.54** | **18.55** | **28.13** | **24.52**  |

Baseline denotes the performance of the original model, ICP-ALL denotes ICP with all of the information constraints removed, and ICP-COM denotes ICP with only the competing constraints removed.

The training logs behind the paper's results, e.g. vgg16 on cifar10 with ICP, can be viewed with:
> tensorboard --logdir runs/ICP_cifar10_vgg16

# Trained Models
The trained models used for the paper's results can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1JLQrOvVWbWIXzu_A2l4Ccw) (Password: vd3i) or [Google Drive](https://drive.google.com/drive/folders/19mBHxAVYALPzIQLvvL0uU9-XMLEttBc6?usp=sharing).
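
As a rough sketch of how a downloaded model could be inspected before running the test scripts: `main.py` reads checkpoints from `./checkpoints/ICP_{dataset}_{model}.t7` and stores the seven sub-networks under the keys listed below. Whether the files on Baidu Netdisk or Google Drive already carry these names is an assumption, not something this README states, so rename them if needed. The helper names `checkpoint_path` and `summarize_checkpoint` are hypothetical and not part of the repository.

```python
# Keys that Classification/main.py writes into each checkpoint dictionary.
NETWORK_KEYS = ("netE", "netG_Less", "netG_More", "netG_Total",
                "netD", "netZ2Y", "netY2Z")
METRIC_KEYS = ("acc_total", "acc_less", "acc_more", "epoch")


def checkpoint_path(dataset, model, root="./checkpoints"):
    """Build the path main.py loads, e.g. ./checkpoints/ICP_cifar10_vgg16.t7."""
    return "{}/ICP_{}_{}.t7".format(root, dataset, model)


def summarize_checkpoint(dataset, model, root="./checkpoints"):
    """Hypothetical helper: load a checkpoint on the CPU and return its metrics."""
    import torch  # imported lazily so checkpoint_path works without PyTorch

    # map_location="cpu" lets a checkpoint saved on a GPU open on a CPU-only machine.
    state = torch.load(checkpoint_path(dataset, model, root), map_location="cpu")

    # Fail early if any of the seven sub-network state dicts is absent.
    missing = [key for key in NETWORK_KEYS if key not in state]
    if missing:
        raise KeyError("checkpoint is missing network weights: {}".format(missing))

    return {key: state[key] for key in METRIC_KEYS if key in state}
```

For example, `summarize_checkpoint('cifar10', 'vgg16')` would return the stored `acc_total`, `acc_less`, `acc_more` and `epoch` values, assuming PyTorch is installed and the file exists at that path.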
47 | -------------------------------------------------------------------------------- /Classification/main.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import os 10 | import argparse 11 | import numpy as np 12 | from utils import progress_bar, str2bool 13 | from tensorboardX import SummaryWriter 14 | 15 | parser = argparse.ArgumentParser(description='ICP CIFAR10/CIFAR100 Training.') 16 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 17 | parser.add_argument('--train', default=True, type=str2bool, help='train or test') 18 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name: [cifar10, cifar100].') 19 | parser.add_argument('--model', default='vgg16', type=str, help='model: [vgg16, googlenet, resnet20, densenet40].') 20 | parser.add_argument('--epoch', default=90, type=int, help='the number of epoch.') 21 | parser.add_argument('--lr_decay_epochs', default=[30,60], nargs='+', type=int, help='the epoch to decay the learning rate.') 22 | 23 | parser.add_argument('--gamma', default=0.01, type=float, help='Compete - MLP') 24 | parser.add_argument('--alpha', default=0.01, type=float, help='Max - DIS') 25 | parser.add_argument('--beta', default=0.001, type=float, help='MIN - KL') 26 | parser.add_argument('--rec', default=0.1, type=float, help='Synergy - REC') 27 | #parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 28 | args = parser.parse_args() 29 | print(args) 30 | 31 | if args.train: 32 | writer = SummaryWriter('runs/ICP_{}_{}'.format(args.dataset, args.model)) 33 | 34 | torch.manual_seed(233) 35 | torch.cuda.manual_seed(233) 36 | np.random.seed(233) 37 | 38 | device = 'cuda' if 
torch.cuda.is_available() else 'cpu' 39 | best_acc = 0 # best test accuracy 40 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 41 | if device == 'cuda': 42 | cudnn.benchmark = True 43 | 44 | # Data 45 | print('==> Preparing data..') 46 | if args.dataset.lower() == 'cifar10': 47 | transform_train = transforms.Compose([ 48 | transforms.RandomCrop(32, padding=4), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 52 | ]) 53 | 54 | transform_test = transforms.Compose([ 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 57 | ]) 58 | 59 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 60 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)#, drop_last = True 61 | 62 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 63 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 64 | 65 | # Model 66 | print('==> Building model..') 67 | netE, netG_Less, netG_More, netG_Total, netD, netZ2Y, netY2Z = \ 68 | get_model(args.dataset, args.model, device) 69 | 70 | elif args.dataset.lower() == 'cifar100': 71 | transform_train = transforms.Compose([ 72 | transforms.RandomCrop(32, padding=4), 73 | transforms.RandomHorizontalFlip(), 74 | transforms.RandomRotation(15), 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343), (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))]) 77 | 78 | transform_test = transforms.Compose([ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343), (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))]) 81 | 82 | trainset = 
torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 83 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 84 | 85 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 86 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 87 | 88 | # Model 89 | print('==> Building model..') 90 | netE, netG_Less, netG_More, netG_Total, netD, netZ2Y, netY2Z = \ 91 | get_model(args.dataset, args.model, device) 92 | 93 | else: 94 | raise NotImplementedError 95 | 96 | 97 | criterion = nn.CrossEntropyLoss() 98 | loss_MSE = nn.MSELoss() 99 | def loss_KL(mu, logvar): 100 | kld = torch.mean(-0.5*(1+logvar-mu**2-torch.exp(logvar)).sum(1)) 101 | return kld 102 | 103 | optimizerG = optim.SGD([{'params' : netE.parameters()}, 104 | {'params' : netG_Less.parameters()}, 105 | {'params' : netG_More.parameters()}, 106 | {'params' : netG_Total.parameters()}], lr=args.lr, momentum=0.9, weight_decay=5e-4) 107 | optimizerD = optim.SGD(netD.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 108 | optimizerMLP = optim.SGD([{'params' : netZ2Y.parameters()}, 109 | {'params' : netY2Z.parameters()}], lr=args.lr, momentum=0.9, weight_decay=5e-4) 110 | 111 | schedulerG = optim.lr_scheduler.MultiStepLR(optimizerG, milestones = args.lr_decay_epochs, gamma=0.1) 112 | schedulerD = optim.lr_scheduler.MultiStepLR(optimizerD, milestones = args.lr_decay_epochs, gamma=0.1) 113 | schedulerMLP = optim.lr_scheduler.MultiStepLR(optimizerMLP, milestones = args.lr_decay_epochs, gamma=0.1) 114 | 115 | # hyper-parameters 116 | gamma = args.gamma 117 | alpha = args.alpha 118 | beta = args.beta 119 | rec = args.rec 120 | 121 | # Train 122 | def train(epoch): 123 | print('\nEpoch: %d' % epoch) 124 | netE.train() 125 | netG_Less.train() 126 | netG_More.train() 127 | netG_Total.train() 128 | netD.train() 129 | netZ2Y.train() 130 
| netY2Z.train() 131 | 132 | train_loss = 0 133 | train_MLP = 0 134 | train_D = 0 135 | train_kl = 0 136 | train_less_rec = 0 137 | train_more_rec = 0 138 | train_total_rec = 0 139 | 140 | correct = 0 141 | total = 0 142 | for batch_idx, (x, targets) in enumerate(trainloader): 143 | x, targets = x.to(device), targets.to(device) 144 | 145 | ## Update MLP 146 | z, mu, logvar, y = netE(x) 147 | rec_y = netZ2Y(z) 148 | rec_z = netY2Z(y) 149 | 150 | # loss MLP 151 | # or: (to prevent gradient explosion - nan) 152 | # gamma * (F.mse_loss(rec_z, z.detach(), reduction='sum').div(self.batch_size) \ 153 | # + F.mse_loss(rec_y, y.detach(), reduction='sum').div(self.batch_size)) 154 | loss_MLP = gamma * (loss_MSE(rec_z, z.detach()) + loss_MSE(rec_y, y.detach())) #/ args.batch_size 155 | optimizerMLP.zero_grad() 156 | loss_MLP.backward() 157 | optimizerMLP.step() 158 | 159 | # loss D 160 | index = np.arange(x.size()[0]) 161 | np.random.shuffle(index) 162 | y_shuffle = y.clone() 163 | y_shuffle = y_shuffle[index, :] 164 | 165 | real_score = netD(torch.cat([y.detach(), y.detach()], dim=1)) 166 | fake_score = netD(torch.cat([y.detach(), y_shuffle.detach()], dim=1)) 167 | 168 | # or: (to prevent gradient explosion - nan) 169 | # alpha * (F.binary_cross_entropy(real_score, ones, reduction='sum').div(self.batch_size) \ 170 | # + F.binary_cross_entropy(fake_score, zeros, reduction='sum').div(self.batch_size)) 171 | loss_D = -alpha * torch.mean(torch.log(real_score) + torch.log(1 - fake_score)) 172 | 173 | optimizerD.zero_grad() 174 | loss_D.backward() 175 | optimizerD.step() 176 | 177 | ## Update G 178 | z, mu, logvar, y = netE(x) 179 | 180 | rec_y = netZ2Y(z) 181 | rec_z = netY2Z(y) 182 | 183 | rec_less_x = netG_Less(z) 184 | rec_more_x = netG_More(y) 185 | rec_x = netG_Total(torch.cat([z, y], dim=1)) 186 | 187 | # loss MLP 188 | # or: (to prevent gradient explosion - nan) 189 | # (F.mse_loss(rec_z, z.detach(), reduction='sum').div(self.batch_size) \ 190 | # + F.mse_loss(rec_y, 
y.detach(), reduction='sum').div(self.batch_size)) 191 | loss_MLP = (loss_MSE(rec_z, z.detach()) + loss_MSE(rec_y, y.detach())) #/ args.batch_size 192 | 193 | # loss D 194 | index = np.arange(x.size()[0]) 195 | np.random.shuffle(index) 196 | y_shuffle = y.clone() 197 | y_shuffle = y_shuffle[index, :] 198 | 199 | real_score = netD(torch.cat([y, y.detach()], dim=1)) 200 | fake_score = netD(torch.cat([y, y_shuffle.detach()], dim=1)) 201 | 202 | # or: (to prevent gradient explosion - nan) 203 | # alpha * (F.binary_cross_entropy(real_score, ones, reduction='sum').div(self.batch_size) \ 204 | # + F.binary_cross_entropy(fake_score, zeros, reduction='sum').div(self.batch_size)) 205 | loss_D = -torch.mean(torch.log(real_score) + torch.log(1 - fake_score)) 206 | 207 | # loss KL 208 | loss_kl = loss_KL(mu, logvar) 209 | 210 | # loss Rec 211 | loss_less_rec = criterion(rec_less_x, targets) 212 | loss_more_rec = criterion(rec_more_x, targets) 213 | loss_total_rec = criterion(rec_x, targets) 214 | loss_rec = loss_less_rec + loss_more_rec + loss_total_rec 215 | 216 | # total Loss 217 | loss_total = rec * loss_rec + beta * loss_kl + alpha * loss_D - gamma * loss_MLP 218 | 219 | optimizerG.zero_grad() 220 | loss_total.backward() 221 | optimizerG.step() 222 | 223 | train_loss += loss_total.item() 224 | train_MLP += loss_MLP.item() 225 | train_D += loss_D.item() 226 | train_kl += loss_kl.item() 227 | train_less_rec += loss_less_rec.item() 228 | train_more_rec += loss_more_rec.item() 229 | train_total_rec += loss_total_rec.item() 230 | 231 | _, predicted = rec_x.max(1) 232 | total += targets.size(0) 233 | correct += predicted.eq(targets).sum().item() 234 | 235 | # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss_total: {:.5f}, loss_MLP: {:.5f}, loss_D: {:.5f}, loss_kl: {:.5f}, loss_less_rec: {:.5f}, loss_more_rec: {:.5f}, loss_total_rec: {:.5f}'.format( 236 | # epoch, batch_idx * len(x), len(trainloader.dataset), 237 | # 100. 
* batch_idx / len(trainloader), 238 | # loss_total.item(), loss_MLP.item(), loss_D.item(), loss_kl.item(), 239 | # loss_less_rec.item(), loss_more_rec.item(), loss_total_rec.item())) 240 | 241 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 242 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 243 | 244 | writer.add_scalar('train_loss', train_loss, epoch) 245 | writer.add_scalar('train_MLP', train_MLP, epoch) 246 | writer.add_scalar('train_D', train_D, epoch) 247 | writer.add_scalar('train_kl', train_kl, epoch) 248 | writer.add_scalar('train_less_rec', train_less_rec, epoch) 249 | writer.add_scalar('train_more_rec', train_more_rec, epoch) 250 | writer.add_scalar('train_total_rec', train_total_rec, epoch) 251 | 252 | # Test 253 | def test(epoch = 0): 254 | global best_acc 255 | netE.eval() 256 | netG_Less.eval() 257 | netG_More.eval() 258 | netG_Total.eval() 259 | netD.eval() 260 | netZ2Y.eval() 261 | netY2Z.eval() 262 | 263 | # Total 264 | test_loss = 0 265 | correct_total = 0 266 | correct_less = 0 267 | correct_more = 0 268 | total = 0 269 | with torch.no_grad(): 270 | for batch_idx, (x, targets) in enumerate(testloader): 271 | x, targets = x.to(device), targets.to(device) 272 | z, mu, logvar, y = netE(x) 273 | outputs = netG_Total(torch.cat([mu, y], dim=1)) 274 | loss = criterion(outputs, targets) 275 | 276 | test_loss += loss.item() 277 | _, predicted = outputs.max(1) 278 | total += targets.size(0) 279 | correct_total += predicted.eq(targets).sum().item() 280 | 281 | outputs_less = netG_Less(mu) 282 | _, predicted_less = outputs_less.max(1) 283 | correct_less += predicted_less.eq(targets).sum().item() 284 | 285 | outputs_more = netG_More(y) 286 | _, predicted_more = outputs_more.max(1) 287 | correct_more += predicted_more.eq(targets).sum().item() 288 | 289 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 290 | % (test_loss/(batch_idx+1), 100.*correct_total/total, correct_total, 
total)) 291 | 292 | # Save checkpoint. 293 | acc_total = 100.*correct_total/total 294 | 295 | acc_less = 100.*correct_less/total 296 | acc_more = 100.*correct_more/total 297 | 298 | print('Error Rate: ', 100.0 - acc_total) 299 | 300 | if args.train: 301 | writer.add_scalar('accuracy', acc_total, epoch) 302 | writer.add_scalar('accuracy_less', acc_less, epoch) 303 | writer.add_scalar('accuracy_more', acc_more, epoch) 304 | 305 | if acc_total > best_acc: 306 | print('Saving..') 307 | state = { 308 | 'netE': netE.state_dict(), 309 | 'netG_Less': netG_Less.state_dict(), 310 | 'netG_More': netG_More.state_dict(), 311 | 'netG_Total': netG_Total.state_dict(), 312 | 'netD': netD.state_dict(), 313 | 'netZ2Y': netZ2Y.state_dict(), 314 | 'netY2Z': netY2Z.state_dict(), 315 | 316 | 'acc_total': acc_total, 317 | 'acc_less': acc_less, 318 | 'acc_more': acc_more, 319 | 'epoch': epoch, 320 | } 321 | if not os.path.isdir('checkpoints'): 322 | os.mkdir('checkpoints') 323 | torch.save(state, './checkpoints/ICP_{}_{}.t7'.format(args.dataset, args.model)) 324 | best_acc = acc_total 325 | 326 | 327 | if __name__ == '__main__': 328 | 329 | if args.train: 330 | for epoch in range(start_epoch, start_epoch + args.epoch): 331 | schedulerG.step() 332 | schedulerD.step() 333 | schedulerMLP.step() 334 | train(epoch) 335 | test(epoch) 336 | else: 337 | print('==> Resuming from checkpoint..') 338 | assert os.path.isdir('checkpoints'), 'Error: no checkpoint directory found!' 
339 | checkpoint = torch.load('./checkpoints/ICP_{}_{}.t7'.format(args.dataset, args.model), map_location=device) 340 | netE.load_state_dict(checkpoint['netE']) 341 | netG_Less.load_state_dict(checkpoint['netG_Less']) 342 | netG_More.load_state_dict(checkpoint['netG_More']) 343 | netG_Total.load_state_dict(checkpoint['netG_Total']) 344 | netD.load_state_dict(checkpoint['netD']) 345 | netZ2Y.load_state_dict(checkpoint['netZ2Y']) 346 | netY2Z.load_state_dict(checkpoint['netY2Z']) 347 | print("=> loaded checkpoint ICP_{}_{}.".format(args.dataset, args.model)) 348 | test() 349 | -------------------------------------------------------------------------------- /Classification/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def reparametrize(mu, logvar): 7 | std = logvar.mul(0.5).exp_() 8 | eps = std.data.new(std.size()).normal_() 9 | return mu + std * eps 10 | 11 | class Discriminator(nn.Module): 12 | def __init__(self, y_dim): 13 | super(Discriminator, self).__init__() 14 | self.net = nn.Sequential( 15 | nn.Linear(y_dim * 2, y_dim), 16 | nn.LeakyReLU(0.2, True), 17 | nn.Linear(y_dim, y_dim), 18 | nn.LeakyReLU(0.2, True), 19 | nn.Linear(y_dim, 1), 20 | nn.Sigmoid() 21 | ) 22 | 23 | def forward(self, y): 24 | return self.net(y).squeeze() 25 | 26 | class MLP(nn.Module): 27 | def __init__(self, s_dim, t_dim): 28 | super(MLP, self).__init__() 29 | self.net = nn.Sequential( 30 | nn.Linear(s_dim, s_dim), 31 | nn.LeakyReLU(0.2, True), 32 | nn.Linear(s_dim, t_dim), 33 | nn.LeakyReLU(0.2, True), 34 | nn.Linear(t_dim, t_dim), 35 | nn.ReLU() 36 | ) 37 | 38 | def forward(self, s): 39 | t = self.net(s) 40 | return t 41 | 42 | ## ---------------- CIFAR10 ---------------- 43 | # ----------------- VGG ----------------- 44 | cfg = { 45 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 46 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256,
'M', 512, 512, 'M', 512, 512, 'M'], 47 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 48 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 49 | } 50 | 51 | class VGG_FeatureExtractor(nn.Module): 52 | def __init__(self, vgg_name): 53 | super(VGG_FeatureExtractor, self).__init__() 54 | self.z_dim = 512 55 | self.y_dim = 512 56 | 57 | self.feature_z = self._make_layers(cfg[vgg_name], self.z_dim * 2) 58 | self.feature_y = self._make_layers(cfg[vgg_name], self.y_dim) 59 | 60 | 61 | def forward(self, x): 62 | distributions = self.feature_z(x) 63 | distributions = distributions.view(distributions.size(0), -1) 64 | mu = distributions[:, : self.z_dim] 65 | logvar = distributions[:, self.z_dim : ] 66 | z = reparametrize(mu, logvar) 67 | 68 | y = self.feature_y(x) 69 | y = y.view(y.size(0), -1) 70 | return z, mu, logvar, y 71 | 72 | def _make_layers(self, cfg, dim): 73 | layers = [] 74 | in_channels = 3 75 | for x in cfg[:-2]: 76 | if x == 'M': 77 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 78 | else: 79 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 80 | nn.BatchNorm2d(x), 81 | nn.ReLU(inplace=True)] 82 | in_channels = x 83 | 84 | layers += [nn.Conv2d(in_channels, dim, kernel_size=3, padding=1), 85 | nn.BatchNorm2d(dim), 86 | nn.ReLU(inplace=True)] 87 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 88 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 89 | return nn.Sequential(*layers) 90 | 91 | class VGG_Classifier(nn.Module): 92 | def __init__(self, dim = 512, num_classes = 10): 93 | super(VGG_Classifier, self).__init__() 94 | self.dim = dim 95 | self.classifier = nn.Linear(self.dim, num_classes) 96 | 97 | def forward(self, z): 98 | out = self.classifier(z) 99 | return out 100 | 101 | # ----------------- GoogleNet ----------------- 102 | class Inception(nn.Module): 103 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, 
n5x5red, n5x5, pool_planes): 104 | super(Inception, self).__init__() 105 | # 1x1 conv branch 106 | self.b1 = nn.Sequential( 107 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 108 | nn.BatchNorm2d(n1x1), 109 | nn.ReLU(True), 110 | ) 111 | 112 | # 1x1 conv -> 3x3 conv branch 113 | self.b2 = nn.Sequential( 114 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 115 | nn.BatchNorm2d(n3x3red), 116 | nn.ReLU(True), 117 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 118 | nn.BatchNorm2d(n3x3), 119 | nn.ReLU(True), 120 | ) 121 | 122 | # 1x1 conv -> 5x5 conv branch 123 | self.b3 = nn.Sequential( 124 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 125 | nn.BatchNorm2d(n5x5red), 126 | nn.ReLU(True), 127 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 128 | nn.BatchNorm2d(n5x5), 129 | nn.ReLU(True), 130 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 131 | nn.BatchNorm2d(n5x5), 132 | nn.ReLU(True), 133 | ) 134 | 135 | # 3x3 pool -> 1x1 conv branch 136 | self.b4 = nn.Sequential( 137 | nn.MaxPool2d(3, stride=1, padding=1), 138 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 139 | nn.BatchNorm2d(pool_planes), 140 | nn.ReLU(True), 141 | ) 142 | 143 | def forward(self, x): 144 | y1 = self.b1(x) 145 | y2 = self.b2(x) 146 | y3 = self.b3(x) 147 | y4 = self.b4(x) 148 | return torch.cat([y1,y2,y3,y4], 1) 149 | 150 | class GoogLeNet_FeatureExtractor(nn.Module): 151 | def __init__(self): 152 | super(GoogLeNet_FeatureExtractor, self).__init__() 153 | self.pre_layers_z = nn.Sequential( 154 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 155 | nn.BatchNorm2d(192), 156 | nn.ReLU(True), 157 | ) 158 | self.a3_z = Inception(192, 64, 96, 128, 16, 32, 32) 159 | self.b3_z = Inception(256, 128, 128, 192, 32, 96, 64) 160 | 161 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 162 | 163 | self.a4_z = Inception(480, 192, 96, 208, 16, 48, 64) 164 | self.b4_z = Inception(512, 160, 112, 224, 24, 64, 64) 165 | self.c4_z = Inception(512, 128, 128, 256, 24, 64, 64) 166 | self.d4_z = 
Inception(512, 112, 144, 288, 32, 64, 64) 167 | self.e4_z = Inception(528, 256, 160, 320, 32, 128, 128) 168 | 169 | self.a5_z = Inception(832, 256, 160, 320, 32, 128, 128) 170 | self.b5_z_mu = Inception(832, 384, 192, 384, 48, 128, 128) 171 | self.b5_z_logvar = Inception(832, 384, 192, 384, 48, 128, 128) 172 | 173 | self.pre_layers_y = nn.Sequential( 174 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 175 | nn.BatchNorm2d(192), 176 | nn.ReLU(True), 177 | ) 178 | self.a3_y = Inception(192, 64, 96, 128, 16, 32, 32) 179 | self.b3_y = Inception(256, 128, 128, 192, 32, 96, 64) 180 | self.a4_y = Inception(480, 192, 96, 208, 16, 48, 64) 181 | self.b4_y = Inception(512, 160, 112, 224, 24, 64, 64) 182 | self.c4_y = Inception(512, 128, 128, 256, 24, 64, 64) 183 | self.d4_y = Inception(512, 112, 144, 288, 32, 64, 64) 184 | self.e4_y = Inception(528, 256, 160, 320, 32, 128, 128) 185 | self.a5_y = Inception(832, 256, 160, 320, 32, 128, 128) 186 | self.b5_y = Inception(832, 384, 192, 384, 48, 128, 128) 187 | 188 | self.avgpool = nn.AvgPool2d(8, stride=1) 189 | 190 | def forward(self, x): 191 | out_z = self.pre_layers_z(x) 192 | out_z = self.a3_z(out_z) 193 | out_z = self.b3_z(out_z) 194 | out_z = self.maxpool(out_z) 195 | out_z = self.a4_z(out_z) 196 | out_z = self.b4_z(out_z) 197 | out_z = self.c4_z(out_z) 198 | out_z = self.d4_z(out_z) 199 | out_z = self.e4_z(out_z) 200 | out_z = self.maxpool(out_z) 201 | out_z = self.a5_z(out_z) 202 | 203 | mu = self.b5_z_mu(out_z) 204 | mu = self.avgpool(mu) 205 | mu = mu.view(mu.size(0), -1) 206 | 207 | logvar = self.b5_z_logvar(out_z) 208 | logvar = self.avgpool(logvar) 209 | logvar = logvar.view(logvar.size(0), -1) 210 | z = reparametrize(mu, logvar) 211 | 212 | out_y = self.pre_layers_y(x) 213 | out_y = self.a3_y(out_y) 214 | out_y = self.b3_y(out_y) 215 | out_y = self.maxpool(out_y) 216 | out_y = self.a4_y(out_y) 217 | out_y = self.b4_y(out_y) 218 | out_y = self.c4_y(out_y) 219 | out_y = self.d4_y(out_y) 220 | out_y = self.e4_y(out_y) 
221 | out_y = self.maxpool(out_y) 222 | out_y = self.a5_y(out_y) 223 | out_y = self.b5_y(out_y) 224 | out_y = self.avgpool(out_y) 225 | y = out_y.view(out_y.size(0), -1) 226 | 227 | return z, mu, logvar, y 228 | 229 | class GoogLeNet_Classifier(nn.Module): 230 | def __init__(self, dim = 1024, num_classes=10): 231 | super(GoogLeNet_Classifier, self).__init__() 232 | self.classifier = nn.Linear(dim, num_classes) 233 | 234 | def forward(self, z): 235 | out = self.classifier(z) 236 | return out 237 | 238 | # ----------------- ResNet ----------------- 239 | class LambdaLayer(nn.Module): 240 | def __init__(self, lambd): 241 | super(LambdaLayer, self).__init__() 242 | self.lambd = lambd 243 | 244 | def forward(self, x): 245 | return self.lambd(x) 246 | 247 | class BasicBlock_Res(nn.Module): 248 | expansion = 1 249 | 250 | def __init__(self, in_planes, planes, stride=1, option='A'): 251 | super(BasicBlock_Res, self).__init__() 252 | 253 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 254 | self.bn1 = nn.BatchNorm2d(planes) 255 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 256 | self.bn2 = nn.BatchNorm2d(planes) 257 | 258 | self.shortcut = nn.Sequential() 259 | if stride != 1 or in_planes != planes: 260 | if option == 'A': 261 | """ 262 | For CIFAR10 ResNet paper uses option A. 
263 | """ 264 | self.shortcut = LambdaLayer(lambda x: 265 | F.pad(x[:, :, ::2, ::2], 266 | (0, 0, 0, 0, (planes - x.shape[1])//2, (planes - x.shape[1])//2), "constant", 0)) 267 | elif option == 'B': 268 | self.shortcut = nn.Sequential( 269 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 270 | nn.BatchNorm2d(self.expansion * planes) 271 | ) 272 | 273 | def forward(self, x): 274 | out = F.relu(self.bn1(self.conv1(x))) 275 | self.x = out 276 | out = self.bn2(self.conv2(out)) 277 | out += self.shortcut(x) 278 | out = F.relu(out) 279 | return out 280 | 281 | class ResNet_FeatureExtractor(nn.Module): 282 | def __init__(self, block = BasicBlock_Res, num_blocks = [3, 3, 3]): 283 | super(ResNet_FeatureExtractor, self).__init__() 284 | 285 | self.in_planes = 16 286 | self.layers_z = [nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 287 | nn.BatchNorm2d(16)] + \ 288 | self._make_layer(block, 16, num_blocks[0], stride=1) + \ 289 | self._make_layer(block, 32, num_blocks[1], stride=2) + \ 290 | self._make_layer(block, 64 * 2, num_blocks[2], stride=2) 291 | 292 | self.in_planes = 16 293 | self.layers_y = [nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 294 | nn.BatchNorm2d(16)] + \ 295 | self._make_layer(block, 16, num_blocks[0], stride=1) + \ 296 | self._make_layer(block, 32, num_blocks[1], stride=2) + \ 297 | self._make_layer(block, 64, num_blocks[2], stride=2) 298 | self.feature_z = nn.Sequential(*self.layers_z) 299 | self.feature_y = nn.Sequential(*self.layers_y) 300 | 301 | def _make_layer(self, block, planes, num_blocks, stride): 302 | strides = [stride] + [1]*(num_blocks-1) 303 | layers = [] 304 | for stride in strides: 305 | layers.append(block(self.in_planes, planes, stride)) 306 | self.in_planes = planes * block.expansion 307 | return layers 308 | 309 | def forward(self, x): 310 | distributions = self.feature_z(x) 311 | distributions = F.avg_pool2d(distributions, 4) 312 | distributions = 
distributions.view(distributions.size(0), -1) 313 | mu = distributions[:, : 256] 314 | logvar = distributions[:, 256 : ] 315 | z = reparametrize(mu, logvar) 316 | 317 | y = self.feature_y(x) 318 | y = F.avg_pool2d(y, 4) 319 | y = y.view(y.size(0), -1) 320 | return z, mu, logvar, y 321 | 322 | class ResNet_Classifier(nn.Module): 323 | def __init__(self, dim = 256, num_classes=10): 324 | super(ResNet_Classifier, self).__init__() 325 | self.classifier = nn.Linear(dim, num_classes) 326 | 327 | def forward(self, z): 328 | out = self.classifier(z) 329 | return out 330 | 331 | # ----------------- DenseNet ----------------- 332 | class Bottleneck(nn.Module): 333 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 334 | super(Bottleneck, self).__init__() 335 | planes = expansion * growthRate 336 | self.bn1 = nn.BatchNorm2d(inplanes) 337 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 338 | self.bn2 = nn.BatchNorm2d(planes) 339 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 340 | padding=1, bias=False) 341 | self.relu = nn.ReLU(inplace=True) 342 | self.dropRate = dropRate 343 | 344 | def forward(self, x): 345 | out = self.bn1(x) 346 | out = self.relu(out) 347 | out = self.conv1(out) 348 | out = self.bn2(out) 349 | out = self.relu(out) 350 | out = self.conv2(out) 351 | if self.dropRate > 0: 352 | out = F.dropout(out, p=self.dropRate, training=self.training) 353 | 354 | out = torch.cat((x, out), 1) 355 | 356 | return out 357 | 358 | class BasicBlock_Dense(nn.Module): 359 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 360 | super(BasicBlock_Dense, self).__init__() 361 | planes = expansion * growthRate 362 | self.bn1 = nn.BatchNorm2d(inplanes) 363 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 364 | padding=1, bias=False) 365 | self.relu = nn.ReLU(inplace=True) 366 | self.dropRate = dropRate 367 | 368 | def forward(self, x): 369 | out = self.bn1(x) 370 | out = self.relu(out) 371 | out 
= self.conv1(out) 372 | if self.dropRate > 0: 373 | out = F.dropout(out, p=self.dropRate, training=self.training) 374 | 375 | out = torch.cat((x, out), 1) 376 | 377 | return out 378 | 379 | class Transition(nn.Module): 380 | def __init__(self, inplanes, outplanes): 381 | super(Transition, self).__init__() 382 | self.bn1 = nn.BatchNorm2d(inplanes) 383 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 384 | bias=False) 385 | self.relu = nn.ReLU(inplace=True) 386 | 387 | def forward(self, x): 388 | out = self.bn1(x) 389 | out = self.relu(out) 390 | out = self.conv1(out) 391 | out = F.avg_pool2d(out, 2) 392 | return out 393 | 394 | class DenseNet_FeatureExtractor(nn.Module): 395 | def __init__(self, depth=22, block=Bottleneck, dropRate=0, num_classes=10, growthRate=12, compressionRate=2): 396 | super(DenseNet_FeatureExtractor, self).__init__() 397 | 398 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 399 | n = (depth - 4) // 3 if block == BasicBlock_Dense else (depth - 4) // 6 400 | 401 | self.growthRate = growthRate 402 | self.dropRate = dropRate 403 | 404 | # self.inplanes is a global variable used across multiple 405 | # helper functions 406 | # z 407 | self.inplanes = growthRate * 2 408 | self.conv1_z = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 409 | bias=False) 410 | self.dense1_z = self._make_denseblock(block, n) 411 | self.trans1_z = self._make_transition(compressionRate) 412 | self.dense2_z = self._make_denseblock(block, n) 413 | self.trans2_z = self._make_transition(compressionRate) 414 | tmp = self.inplanes 415 | self.dense3_z_mu = self._make_denseblock(block, n) 416 | self.inplanes = tmp 417 | self.dense3_z_logvar = self._make_denseblock(block, n) 418 | self.bn_z_mu = nn.BatchNorm2d(self.inplanes) 419 | self.bn_z_logvar = nn.BatchNorm2d(self.inplanes) 420 | self.relu_z = nn.ReLU(inplace=True) 421 | self.avgpool_z = nn.AvgPool2d(8) 422 | 423 | # y 424 | self.growthRate = growthRate 425 | self.inplanes = growthRate * 2 426 | 
self.conv1_y = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 427 | bias=False) 428 | self.dense1_y = self._make_denseblock(block, n) 429 | self.trans1_y = self._make_transition(compressionRate) 430 | self.dense2_y = self._make_denseblock(block, n) 431 | self.trans2_y = self._make_transition(compressionRate) 432 | self.dense3_y = self._make_denseblock(block, n) 433 | self.bn_y = nn.BatchNorm2d(self.inplanes) 434 | self.relu_y = nn.ReLU(inplace=True) 435 | self.avgpool_y = nn.AvgPool2d(8) 436 | 437 | def _make_denseblock(self, block, blocks): 438 | layers = [] 439 | for i in range(blocks): 440 | # Currently we fix the expansion ratio as the default value 441 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 442 | self.inplanes += self.growthRate 443 | 444 | return nn.Sequential(*layers) 445 | 446 | def _make_transition(self, compressionRate): 447 | inplanes = self.inplanes 448 | outplanes = int(math.floor(self.inplanes // compressionRate)) 449 | self.inplanes = outplanes 450 | return Transition(inplanes, outplanes) 451 | 452 | def forward(self, x): 453 | out_z = self.conv1_z(x) 454 | 455 | out_z = self.trans1_z(self.dense1_z(out_z)) 456 | out_z = self.trans2_z(self.dense2_z(out_z)) 457 | mu = self.dense3_z_mu(out_z) 458 | mu = self.bn_z_mu(mu) 459 | mu = self.relu_z(mu) 460 | mu = self.avgpool_z(mu) 461 | mu = mu.view(mu.size(0), -1) 462 | 463 | logvar = self.dense3_z_logvar(out_z) 464 | logvar = self.bn_z_logvar(logvar) 465 | logvar = self.relu_z(logvar) 466 | logvar = self.avgpool_z(logvar) 467 | logvar = logvar.view(logvar.size(0), -1) 468 | 469 | # print(mu.shape, logvar.shape) 470 | z = reparametrize(mu, logvar) 471 | 472 | out_y = self.conv1_y(x) 473 | 474 | out_y = self.trans1_y(self.dense1_y(out_y)) 475 | out_y = self.trans2_y(self.dense2_y(out_y)) 476 | out_y = self.dense3_y(out_y) 477 | out_y = self.bn_y(out_y) 478 | out_y = self.relu_y(out_y) 479 | out_y = self.avgpool_y(out_y) 480 | y = 
out_y.view(out_y.size(0), -1) 481 | # print(z.shape,y.shape) 482 | return z, mu, logvar, y 483 | 484 | class DenseNet_Classifier(nn.Module): 485 | def __init__(self, dim = 456, num_classes=10): 486 | super(DenseNet_Classifier, self).__init__() 487 | self.classifier = nn.Linear(dim, num_classes) 488 | 489 | def forward(self, z): 490 | out = self.classifier(z) 491 | return out 492 | # -------------------- Get Model -------------------- 493 | def get_model(dataset_name, model_name, device): 494 | if dataset_name.lower() == 'cifar10': 495 | num_classes = 10 496 | elif dataset_name.lower() == 'cifar100': 497 | num_classes = 100 498 | else: 499 | raise NotImplementedError 500 | 501 | if model_name.lower() == 'vgg16': 502 | netE = VGG_FeatureExtractor('VGG16').to(device) 503 | netG_Less = VGG_Classifier(dim = 512, num_classes = num_classes).to(device) 504 | netG_More = VGG_Classifier(dim = 512, num_classes = num_classes).to(device) 505 | netG_Total = VGG_Classifier(dim = 512 * 2, num_classes = num_classes).to(device) 506 | netD = Discriminator(512).to(device) 507 | netZ2Y = MLP(s_dim = 512, t_dim = 512).to(device) 508 | netY2Z = MLP(s_dim = 512, t_dim = 512).to(device) 509 | 510 | elif model_name.lower() == 'googlenet': 511 | netE = GoogLeNet_FeatureExtractor(num_classes = num_classes).to(device) 512 | netG_Less = GoogLeNet_Classifier(dim = 1024, num_classes = num_classes).to(device) 513 | netG_More = GoogLeNet_Classifier(dim = 1024, num_classes = num_classes).to(device) 514 | netG_Total = GoogLeNet_Classifier(dim = 1024 * 2, num_classes = num_classes).to(device) 515 | 516 | netD = Discriminator(1024).to(device) 517 | netZ2Y = MLP(s_dim = 1024, t_dim = 1024).to(device) 518 | netY2Z = MLP(s_dim = 1024, t_dim = 1024).to(device) 519 | if device == 'cuda': 520 | netE = torch.nn.DataParallel(netE) 521 | netG_Less = torch.nn.DataParallel(netG_Less) 522 | netG_More = torch.nn.DataParallel(netG_More) 523 | netG_Total = torch.nn.DataParallel(netG_Total) 524 | netD = 
torch.nn.DataParallel(netD) 525 | netZ2Y = torch.nn.DataParallel(netZ2Y) 526 | netY2Z = torch.nn.DataParallel(netY2Z) 527 | 528 | elif model_name.lower() == 'resnet20': 529 | netE = ResNet_FeatureExtractor().to(device) 530 | netG_Less = ResNet_Classifier(dim = 256, num_classes = num_classes).to(device) 531 | netG_More = ResNet_Classifier(dim = 256, num_classes = num_classes).to(device) 532 | netG_Total = ResNet_Classifier(dim = 256 * 2, num_classes = num_classes).to(device) 533 | 534 | netD = Discriminator(256).to(device) 535 | netZ2Y = MLP(s_dim = 256, t_dim = 256).to(device) 536 | netY2Z = MLP(s_dim = 256, t_dim = 256).to(device) 537 | 538 | elif model_name.lower() == 'densenet40': 539 | netE = DenseNet_FeatureExtractor(block=BasicBlock_Dense, depth=40, dropRate=0, num_classes=num_classes, growthRate=12, compressionRate=1).to(device) 540 | netG_Less = DenseNet_Classifier(dim = 456, num_classes = num_classes).to(device) 541 | netG_More = DenseNet_Classifier(dim = 456, num_classes = num_classes).to(device) 542 | netG_Total = DenseNet_Classifier(dim = 456 * 2, num_classes = num_classes).to(device) 543 | 544 | netD = Discriminator(456).to(device) 545 | netZ2Y = MLP(s_dim = 456, t_dim = 456).to(device) 546 | netY2Z = MLP(s_dim = 456, t_dim = 456).to(device) 547 | 548 | else: 549 | raise NotImplementedError 550 | 551 | return netE, netG_Less, netG_More, netG_Total, netD, netZ2Y, netY2Z 552 | -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar100_densenet40/events.out.tfevents.1555916855.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar100_densenet40/events.out.tfevents.1555916855.MAC-U3S2 -------------------------------------------------------------------------------- 
/Classification/runs/ICP_cifar100_googlenet/events.out.tfevents.1556190825.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar100_googlenet/events.out.tfevents.1556190825.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar100_resnet20/events.out.tfevents.1556005874.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar100_resnet20/events.out.tfevents.1556005874.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar100_vgg16/events.out.tfevents.1555905801.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar100_vgg16/events.out.tfevents.1555905801.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar10_densenet40/events.out.tfevents.1555551815.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar10_densenet40/events.out.tfevents.1555551815.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar10_googlenet/events.out.tfevents.1556112349.MAC-U3S2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar10_googlenet/events.out.tfevents.1556112349.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar10_resnet20/events.out.tfevents.1555501545.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar10_resnet20/events.out.tfevents.1555501545.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar10_resnet20/events.out.tfevents.1569207192.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar10_resnet20/events.out.tfevents.1569207192.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar10_resnet20/events.out.tfevents.1569214119.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar10_resnet20/events.out.tfevents.1569214119.MAC-U3S2 -------------------------------------------------------------------------------- /Classification/runs/ICP_cifar10_vgg16/events.out.tfevents.1555379025.MAC-U3S2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Classification/runs/ICP_cifar10_vgg16/events.out.tfevents.1555379025.MAC-U3S2 
-------------------------------------------------------------------------------- /Classification/scripts/Test_cifar100_densenet40.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar100 \ 3 | --model densenet40 \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Test_cifar100_googlenet.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python main.py \ 2 | --dataset cifar100 \ 3 | --model googlenet \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Test_cifar100_resnet20.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar100 \ 3 | --model resnet20 \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Test_cifar100_vgg16.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar100 \ 3 | --model vgg16 \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Test_cifar10_densenet40.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar10 \ 3 | --model densenet40 \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Test_cifar10_googlenet.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python main.py \ 2 | --dataset cifar10 \ 3 | --model googlenet \ 4 | --train False 5 | 
-------------------------------------------------------------------------------- /Classification/scripts/Test_cifar10_resnet20.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar10 \ 3 | --model resnet20 \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Test_cifar10_vgg16.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar10 \ 3 | --model vgg16 \ 4 | --train False 5 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar100_densenet40.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar100 \ 3 | --model densenet40 \ 4 | --epoch 300 \ 5 | --gamma 0.001 --alpha 0.001 --beta 0.0001 --rec 0.1 \ 6 | --lr_decay_epochs 150 225 7 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar100_googlenet.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1,2 python main.py \ 2 | --dataset cifar100 \ 3 | --model googlenet \ 4 | --epoch 200 \ 5 | --gamma 0.01 --alpha 0.01 --beta 0.001 --rec 0.1 \ 6 | --lr_decay_epochs 100 150 7 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar100_resnet20.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar100 \ 3 | --model resnet20 \ 4 | --epoch 200 \ 5 | --gamma 0.01 --alpha 0.01 --beta 0.001 --rec 0.1 \ 6 | --lr_decay_epochs 100 150 7 | -------------------------------------------------------------------------------- 
/Classification/scripts/Train_cifar100_vgg16.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar100 \ 3 | --model vgg16 \ 4 | --epoch 200 \ 5 | --gamma 0.001 --alpha 0.001 --beta 0.0001 --rec 0.1 \ 6 | --lr_decay_epochs 100 150 7 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar10_densenet40.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar10 \ 3 | --model densenet40 \ 4 | --epoch 300 \ 5 | --lr_decay_epochs 150 225 6 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar10_googlenet.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1,2 python main.py \ 2 | --dataset cifar10 \ 3 | --model googlenet \ 4 | --epoch 200 \ 5 | --lr_decay_epochs 100 150 6 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar10_resnet20.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar10 \ 3 | --model resnet20 \ 4 | --epoch 200 \ 5 | --lr_decay_epochs 100 150 6 | -------------------------------------------------------------------------------- /Classification/scripts/Train_cifar10_vgg16.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --dataset cifar10 \ 3 | --model vgg16 \ 4 | --epoch 90 \ 5 | --lr_decay_epochs 30 60 6 | -------------------------------------------------------------------------------- /Classification/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - 
get_mean_and_std: calculate the mean and std value of dataset. 3 | - init_params: net parameter initialization. 4 | - progress_bar: progress bar that mimics xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | import argparse 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | def str2bool(v): 15 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 16 | return True 17 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 18 | return False 19 | else: 20 | raise argparse.ArgumentTypeError('Boolean value expected.') 21 | 22 | 23 | def get_mean_and_std(dataset): 24 | '''Compute the mean and std value of dataset.''' 25 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 26 | mean = torch.zeros(3) 27 | std = torch.zeros(3) 28 | print('==> Computing mean and std..') 29 | for inputs, targets in dataloader: 30 | for i in range(3): 31 | mean[i] += inputs[:,i,:,:].mean() 32 | std[i] += inputs[:,i,:,:].std() 33 | mean.div_(len(dataset)) 34 | std.div_(len(dataset)) 35 | return mean, std 36 | 37 | def init_params(net): 38 | '''Init layer parameters.''' 39 | for m in net.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | init.kaiming_normal_(m.weight, mode='fan_out') 42 | if m.bias is not None: 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.BatchNorm2d): 45 | init.constant_(m.weight, 1) 46 | init.constant_(m.bias, 0) 47 | elif isinstance(m, nn.Linear): 48 | init.normal_(m.weight, std=1e-3) 49 | if m.bias is not None: 50 | init.constant_(m.bias, 0) 51 | 52 | 53 | _, term_width = os.popen('stty size', 'r').read().split() 54 | term_width = int(term_width) 55 | 56 | TOTAL_BAR_LENGTH = 65. 57 | last_time = time.time() 58 | begin_time = last_time 59 | def progress_bar(current, total, msg=None): 60 | global last_time, begin_time 61 | if current == 0: 62 | begin_time = time.time() # Reset for new bar.
63 | 64 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 65 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 66 | 67 | sys.stdout.write(' [') 68 | for i in range(cur_len): 69 | sys.stdout.write('=') 70 | sys.stdout.write('>') 71 | for i in range(rest_len): 72 | sys.stdout.write('.') 73 | sys.stdout.write(']') 74 | 75 | cur_time = time.time() 76 | step_time = cur_time - last_time 77 | last_time = cur_time 78 | tot_time = cur_time - begin_time 79 | 80 | L = [] 81 | L.append(' Step: %s' % format_time(step_time)) 82 | L.append(' | Tot: %s' % format_time(tot_time)) 83 | if msg: 84 | L.append(' | ' + msg) 85 | 86 | msg = ''.join(L) 87 | sys.stdout.write(msg) 88 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 89 | sys.stdout.write(' ') 90 | 91 | # Go back to the center of the bar. 92 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 93 | sys.stdout.write('\b') 94 | sys.stdout.write(' %d/%d ' % (current+1, total)) 95 | 96 | if current < total-1: 97 | sys.stdout.write('\r') 98 | else: 99 | sys.stdout.write('\n') 100 | sys.stdout.flush() 101 | 102 | def format_time(seconds): 103 | days = int(seconds / 3600/24) 104 | seconds = seconds - days*3600*24 105 | hours = int(seconds / 3600) 106 | seconds = seconds - hours*3600 107 | minutes = int(seconds / 60) 108 | seconds = seconds - minutes*60 109 | secondsf = int(seconds) 110 | seconds = seconds - secondsf 111 | millis = int(seconds*1000) 112 | 113 | f = '' 114 | i = 1 115 | if days > 0: 116 | f += str(days) + 'D' 117 | i += 1 118 | if hours > 0 and i <= 2: 119 | f += str(hours) + 'h' 120 | i += 1 121 | if minutes > 0 and i <= 2: 122 | f += str(minutes) + 'm' 123 | i += 1 124 | if secondsf > 0 and i <= 2: 125 | f += str(secondsf) + 's' 126 | i += 1 127 | if millis > 0 and i <= 2: 128 | f += str(millis) + 'ms' 129 | i += 1 130 | if f == '': 131 | f = '0ms' 132 | return f 133 | -------------------------------------------------------------------------------- /Disentanglement/MIG_Score/disentanglement_metrics.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | from torch.utils.data import DataLoader 6 | 7 | import lib.utils as utils 8 | import lib.datasets as dset 9 | # from metric_helpers.loader import load_model_and_dataset 10 | from metric_helpers.mi_metric import compute_metric_shapes, compute_metric_faces 11 | 12 | import lib.dist as dist 13 | 14 | from model import * 15 | 16 | def estimate_entropies(qz_samples, qz_params, q_dist, n_samples=10000, weights=None): 17 | """Computes the term: 18 | E_{p(x)} E_{q(z|x)} [-log q(z)] 19 | and 20 | E_{p(x)} E_{q(z_j|x)} [-log q(z_j)] 21 | where q(z) = 1/N sum_n=1^N q(z|x_n). 22 | Assumes samples are from q(z|x) for *all* x in the dataset. 23 | Assumes that q(z|x) is factorial ie. q(z|x) = prod_j q(z_j|x). 24 | 25 | Computes numerically stable NLL: 26 | - log q(z) = log N - logsumexp_n=1^N log q(z|x_n) 27 | 28 | Inputs: 29 | ------- 30 | qz_samples (K, N) Variable 31 | qz_params (N, K, nparams) Variable 32 | weights (N) Variable 33 | """ 34 | 35 | # Only take a sample subset of the samples 36 | if weights is None: 37 | qz_samples = qz_samples.index_select(1, torch.randperm(qz_samples.size(1))[:n_samples].cuda()) 38 | else: 39 | sample_inds = torch.multinomial(weights, n_samples, replacement=True) 40 | qz_samples = qz_samples.index_select(1, sample_inds) 41 | 42 | K, S = qz_samples.size() 43 | N, _, nparams = qz_params.size() 44 | assert(nparams == q_dist.nparams) 45 | assert(K == qz_params.size(1)) 46 | 47 | if weights is None: 48 | weights = -math.log(N) 49 | else: 50 | weights = torch.log(weights.view(N, 1, 1) / weights.sum()) 51 | 52 | entropies = torch.zeros(K).cuda() 53 | 54 | pbar = tqdm(total=S) 55 | k = 0 56 | while k < S: 57 | batch_size = min(10, S - k) 58 | logqz_i = q_dist.log_density( 59 | qz_samples.view(1, K, S).expand(N, K, S)[:, :, k:k + batch_size], 60 | qz_params.view(N, K, 1, nparams).expand(N, K, S, 
nparams)[:, :, k:k + batch_size]) 61 | k += batch_size 62 | 63 | # computes - log q(z_i) summed over minibatch 64 | entropies += - utils.logsumexp(logqz_i + weights, dim=0, keepdim=False).data.sum(1) 65 | pbar.update(batch_size) 66 | pbar.close() 67 | 68 | entropies /= S 69 | 70 | return entropies 71 | 72 | 73 | def mutual_info_metric_shapes(vae, shapes_dataset): 74 | dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False) 75 | 76 | N = len(dataset_loader.dataset) # number of data samples 77 | K = 10 # number of latent variables 78 | nparams = dist.Normal().nparams 79 | vae.eval() 80 | 81 | print('Computing q(z|x) distributions.') 82 | qz_params = torch.Tensor(N, K, nparams) 83 | 84 | n = 0 85 | for xs in dataset_loader: 86 | batch_size = xs.size(0) 87 | xs = xs.view(batch_size, 1, 64, 64).cuda() 88 | 89 | z, mu, logvar, y = vae(xs) 90 | mu = mu.view(batch_size, K, 1) 91 | logvar = logvar.view(batch_size, K, 1) 92 | target = torch.cat([mu, logvar], dim=2) 93 | 94 | qz_params[n:n + batch_size] = target.view(batch_size, K, nparams).data 95 | n += batch_size 96 | 97 | qz_params = qz_params.view(3, 6, 40, 32, 32, K, nparams).cuda() 98 | qz_samples = dist.Normal().sample(params=qz_params) 99 | 100 | print('Estimating marginal entropies.') 101 | # marginal entropies 102 | marginal_entropies = estimate_entropies( 103 | qz_samples.view(N, K).transpose(0, 1), 104 | qz_params.view(N, K, nparams), 105 | dist.Normal()) 106 | 107 | marginal_entropies = marginal_entropies.cpu() 108 | cond_entropies = torch.zeros(4, K) 109 | 110 | print('Estimating conditional entropies for scale.') 111 | for i in range(6): 112 | qz_samples_scale = qz_samples[:, i, :, :, :, :].contiguous() 113 | qz_params_scale = qz_params[:, i, :, :, :, :].contiguous() 114 | 115 | cond_entropies_i = estimate_entropies( 116 | qz_samples_scale.view(N // 6, K).transpose(0, 1), 117 | qz_params_scale.view(N // 6, K, nparams), 118 | dist.Normal()) 119 | 120 | cond_entropies[0] += 
cond_entropies_i.cpu() / 6 121 | 122 | print('Estimating conditional entropies for orientation.') 123 | for i in range(40): 124 | qz_samples_scale = qz_samples[:, :, i, :, :, :].contiguous() 125 | qz_params_scale = qz_params[:, :, i, :, :, :].contiguous() 126 | 127 | cond_entropies_i = estimate_entropies( 128 | qz_samples_scale.view(N // 40, K).transpose(0, 1), 129 | qz_params_scale.view(N // 40, K, nparams), 130 | dist.Normal()) 131 | 132 | cond_entropies[1] += cond_entropies_i.cpu() / 40 133 | 134 | print('Estimating conditional entropies for pos x.') 135 | for i in range(32): 136 | qz_samples_scale = qz_samples[:, :, :, i, :, :].contiguous() 137 | qz_params_scale = qz_params[:, :, :, i, :, :].contiguous() 138 | 139 | cond_entropies_i = estimate_entropies( 140 | qz_samples_scale.view(N // 32, K).transpose(0, 1), 141 | qz_params_scale.view(N // 32, K, nparams), 142 | dist.Normal()) 143 | 144 | cond_entropies[2] += cond_entropies_i.cpu() / 32 145 | 146 | print('Estimating conditional entropies for pos y.') 147 | for i in range(32): 148 | qz_samples_scale = qz_samples[:, :, :, :, i, :].contiguous() 149 | qz_params_scale = qz_params[:, :, :, :, i, :].contiguous() 150 | 151 | cond_entropies_i = estimate_entropies( 152 | qz_samples_scale.view(N // 32, K).transpose(0, 1), 153 | qz_params_scale.view(N // 32, K, nparams), 154 | dist.Normal()) 155 | 156 | cond_entropies[3] += cond_entropies_i.cpu() / 32 157 | 158 | metric = compute_metric_shapes(marginal_entropies, cond_entropies) 159 | return metric, marginal_entropies, cond_entropies 160 | 161 | 162 | def mutual_info_metric_faces(vae, shapes_dataset): 163 | dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False) 164 | 165 | N = len(dataset_loader.dataset) # number of data samples 166 | K = 10 # number of latent variables 167 | nparams = dist.Normal().nparams 168 | vae.eval() 169 | 170 | print('Computing q(z|x) distributions.') 171 | qz_params = torch.Tensor(N, K, nparams) 172 | 173 | n
= 0 174 | for xs in dataset_loader: 175 | batch_size = xs.size(0) 176 | xs = xs.view(batch_size, 1, 64, 64).cuda() 177 | 178 | z, mu, logvar, y = vae(xs) 179 | mu = mu.view(batch_size, K, 1) 180 | logvar = logvar.view(batch_size, K, 1) 181 | target = torch.cat([mu, logvar], dim=2) 182 | 183 | qz_params[n:n + batch_size] = target.view(batch_size, K, nparams).data 184 | n += batch_size 185 | 186 | qz_params = qz_params.view(50, 21, 11, 11, K, nparams).cuda() 187 | qz_samples = dist.Normal().sample(params=qz_params) 188 | 189 | print('Estimating marginal entropies.') 190 | # marginal entropies 191 | marginal_entropies = estimate_entropies( 192 | qz_samples.view(N, K).transpose(0, 1), 193 | qz_params.view(N, K, nparams), 194 | dist.Normal()) 195 | 196 | marginal_entropies = marginal_entropies.cpu() 197 | cond_entropies = torch.zeros(3, K) 198 | 199 | print('Estimating conditional entropies for azimuth.') 200 | for i in range(21): 201 | qz_samples_pose_az = qz_samples[:, i, :, :, :].contiguous() 202 | qz_params_pose_az = qz_params[:, i, :, :, :].contiguous() 203 | 204 | cond_entropies_i = estimate_entropies( 205 | qz_samples_pose_az.view(N // 21, K).transpose(0, 1), 206 | qz_params_pose_az.view(N // 21, K, nparams), 207 | dist.Normal()) 208 | 209 | cond_entropies[0] += cond_entropies_i.cpu() / 21 210 | 211 | print('Estimating conditional entropies for elevation.') 212 | for i in range(11): 213 | qz_samples_pose_el = qz_samples[:, :, i, :, :].contiguous() 214 | qz_params_pose_el = qz_params[:, :, i, :, :].contiguous() 215 | 216 | cond_entropies_i = estimate_entropies( 217 | qz_samples_pose_el.view(N // 11, K).transpose(0, 1), 218 | qz_params_pose_el.view(N // 11, K, nparams), 219 | dist.Normal()) 220 | 221 | cond_entropies[1] += cond_entropies_i.cpu() / 11 222 | 223 | print('Estimating conditional entropies for lighting.') 224 | for i in range(11): 225 | qz_samples_lighting = qz_samples[:, :, :, i, :].contiguous() 226 | qz_params_lighting = qz_params[:, :, :, i, 
:].contiguous() 227 | 228 | cond_entropies_i = estimate_entropies( 229 | qz_samples_lighting.view(N // 11, K).transpose(0, 1), 230 | qz_params_lighting.view(N // 11, K, nparams), 231 | dist.Normal()) 232 | 233 | cond_entropies[2] += cond_entropies_i.cpu() / 11 234 | 235 | metric = compute_metric_faces(marginal_entropies, cond_entropies) 236 | return metric, marginal_entropies, cond_entropies 237 | 238 | def setup_data_loaders(dataset, batch_size = 2048, use_cuda=True): 239 | if dataset == 'shapes': 240 | train_set = dset.Shapes() 241 | elif dataset == 'faces': 242 | train_set = dset.Faces() 243 | return train_set 244 | 245 | img_dim = 64 246 | nc = 1 247 | z_dim = 10 248 | y_dim = 2 249 | 250 | if __name__ == '__main__': 251 | import argparse 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument('--checkpt', required=True) 254 | parser.add_argument('--name', type=str, default='shapes') 255 | args = parser.parse_args() 256 | 257 | print(args) 258 | 259 | vae_path = "{}".format(args.checkpt) 260 | vae = ICP_Encoder(z_dim = z_dim, y_dim = y_dim, nc = 1).cuda() 261 | vae.load_state_dict(torch.load(vae_path)) 262 | 263 | dataset = setup_data_loaders(args.name) 264 | 265 | metric, marginal_entropies, cond_entropies = eval('mutual_info_metric_' + args.name)(vae, dataset) 266 | 267 | print('MIG: {:.2f}'.format(metric)) 268 | -------------------------------------------------------------------------------- /Disentanglement/MIG_Score/metric_helpers/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lib.dist as dist 3 | import lib.flows as flows 4 | import vae_quant 5 | 6 | 7 | def load_model_and_dataset(checkpt_filename): 8 | print('Loading model and dataset.') 9 | checkpt = torch.load(checkpt_filename, map_location=lambda storage, loc: storage) 10 | args = checkpt['args'] 11 | state_dict = checkpt['state_dict'] 12 | 13 | # backwards compatibility 14 | if not hasattr(args, 'conv'): 15 | 
args.conv = False 16 | 17 | if not hasattr(args, 'dist') or args.dist == 'normal': 18 | prior_dist = dist.Normal() 19 | q_dist = dist.Normal() 20 | elif args.dist == 'laplace': 21 | prior_dist = dist.Laplace() 22 | q_dist = dist.Laplace() 23 | elif args.dist == 'flow': 24 | prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32) 25 | q_dist = dist.Normal() 26 | 27 | # model 28 | if hasattr(args, 'ncon'): 29 | # InfoGAN 30 | model = infogan.Model( 31 | args.latent_dim, n_con=args.ncon, n_cat=args.ncat, cat_dim=args.cat_dim, use_cuda=True, conv=args.conv) 32 | model.load_state_dict(state_dict, strict=False) 33 | vae = vae_quant.VAE( 34 | z_dim=args.ncon, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv) 35 | vae.encoder = model.encoder 36 | vae.decoder = model.decoder 37 | else: 38 | vae = vae_quant.VAE( 39 | z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv) 40 | vae.load_state_dict(state_dict, strict=False) 41 | 42 | # dataset loader 43 | loader = vae_quant.setup_data_loaders(args) 44 | return vae, loader.dataset, args 45 | -------------------------------------------------------------------------------- /Disentanglement/MIG_Score/metric_helpers/mi_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | metric_name = 'MIG' 4 | 5 | 6 | def MIG(mi_normed): 7 | return torch.mean(mi_normed[:, 0] - mi_normed[:, 1]) 8 | 9 | 10 | def compute_metric_shapes(marginal_entropies, cond_entropies): 11 | factor_entropies = [6, 40, 32, 32] 12 | mutual_infos = marginal_entropies[None] - cond_entropies 13 | mutual_infos = torch.sort(mutual_infos, dim=1, descending=True)[0].clamp(min=0) 14 | mi_normed = mutual_infos / torch.Tensor(factor_entropies).log()[:, None] 15 | metric = eval(metric_name)(mi_normed) 16 | return metric 17 | 18 | 19 | def compute_metric_faces(marginal_entropies, cond_entropies): 20 | factor_entropies = [21, 11, 11] 21 | 
mutual_infos = marginal_entropies[None] - cond_entropies 22 | mutual_infos = torch.sort(mutual_infos, dim=1, descending=True)[0].clamp(min=0) 23 | mi_normed = mutual_infos / torch.Tensor(factor_entropies).log()[:, None] 24 | metric = eval(metric_name)(mi_normed) 25 | return metric 26 | 27 | -------------------------------------------------------------------------------- /Disentanglement/MIG_Score/model.py: -------------------------------------------------------------------------------- 1 | """model.py""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | #import torch.nn.functional as F 6 | import torch.nn.init as init 7 | from torch.autograd import Variable 8 | 9 | def reparametrize(mu, logvar): 10 | std = logvar.div(2).exp() 11 | eps = Variable(std.data.new(std.size()).normal_()) 12 | return mu + std*eps 13 | 14 | class View(nn.Module): 15 | def __init__(self, size): 16 | super(View, self).__init__() 17 | self.size = size 18 | 19 | def forward(self, tensor): 20 | return tensor.view(self.size) 21 | 22 | class ICP_Encoder(nn.Module): 23 | def __init__(self, z_dim, y_dim, nc): 24 | super(ICP_Encoder, self).__init__() 25 | 26 | self.nc = nc 27 | self.z_dim = z_dim 28 | self.y_dim = y_dim 29 | 30 | sequence = [nn.Conv2d(nc, 32, 4, 2, 1), 31 | nn.ReLU(True), 32 | nn.Conv2d(32, 32, 4, 2, 1), 33 | nn.ReLU(True), 34 | nn.Conv2d(32, 64, 4, 2, 1), 35 | nn.ReLU(True), 36 | nn.Conv2d(64, 64, 4, 2, 1), 37 | nn.ReLU(True), 38 | nn.Conv2d(64, 256, 4, 1), 39 | nn.ReLU(True), 40 | View((-1, 256*1*1))] 41 | 42 | sequence_z = sequence + [nn.Linear(256, z_dim * 2)] 43 | sequence_y = sequence + [nn.Linear(256, y_dim)] 44 | 45 | self.encoder_z = nn.Sequential(*sequence_z) 46 | self.encoder_y = nn.Sequential(*sequence_y) 47 | 48 | def forward(self, x): 49 | distributions = self.encoder_z(x) 50 | mu = distributions[:, : self.z_dim] 51 | logvar = distributions[:, self.z_dim : ] 52 | z = reparametrize(mu, logvar) 53 | y = self.encoder_y(x) 54 | 55 | return z, mu, logvar, y 56 | 57 | class 
ICP_Decoder(nn.Module):
58 |     def __init__(self, dim, nc):
59 |         super(ICP_Decoder, self).__init__()
60 |
61 |         self.nc = nc
62 |         self.dim = dim
63 |
64 |         self.decoder = nn.Sequential(
65 |             nn.Linear(dim, 256),
66 |             View((-1, 256, 1, 1)),
67 |             nn.ReLU(True),
68 |             nn.ConvTranspose2d(256, 64, 4),
69 |             nn.ReLU(True),
70 |             nn.ConvTranspose2d(64, 64, 4, 2, 1),
71 |             nn.ReLU(True),
72 |             nn.ConvTranspose2d(64, 32, 4, 2, 1),
73 |             nn.ReLU(True),
74 |             nn.ConvTranspose2d(32, 32, 4, 2, 1),
75 |             nn.ReLU(True),
76 |             nn.ConvTranspose2d(32, nc, 4, 2, 1)
77 |         )
78 |
79 |     def forward(self, zy):
80 |
81 |         x_rec = self.decoder(zy)
82 |
83 |         return x_rec
84 |
85 | class Discriminator(nn.Module):
86 |     def __init__(self, y_dim):
87 |         super(Discriminator, self).__init__()
88 |         self.net = nn.Sequential(
89 |             nn.Linear(y_dim * 2, y_dim),
90 |             nn.LeakyReLU(0.2, True),
91 |             nn.Linear(y_dim, y_dim),
92 |             nn.LeakyReLU(0.2, True),
93 |             nn.Linear(y_dim, 1),
94 |             nn.Sigmoid()
95 |         )
96 |
97 |     def forward(self, y):
98 |         return self.net(y).squeeze()
99 |
100 | class MLP(nn.Module):
101 |     def __init__(self, s_dim, t_dim):
102 |         super(MLP, self).__init__()
103 |         self.net = nn.Sequential(
104 |             nn.Linear(s_dim, s_dim),
105 |             nn.LeakyReLU(0.2, True),
106 |             nn.Linear(s_dim, t_dim),
107 |             nn.LeakyReLU(0.2, True),
108 |             nn.Linear(t_dim, t_dim),
109 |             nn.ReLU()
110 |         )
111 |
112 |     def forward(self, s):
113 |         t = self.net(s)
114 |         return t
115 |
--------------------------------------------------------------------------------
/Disentanglement/README.md:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | - Python 3
3 | - PyTorch 1.0
4 | - tensorboardX
5 |
6 | # Prepare Datasets
7 | Data preparation follows the same steps as [FactorVAE](https://github.com/1Konny/FactorVAE).
8 |
9 | 1. For dSprites Dataset:
10 |
11 | > chmod +x ./scripts/prepare_data.sh
12 |
13 | > sh scripts/prepare_data.sh dsprites
14 |
15 | 2.
For CelebA Dataset ([download](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)):
16 |
17 | > download img_align_celeba.zip into ./data/
18 |
19 | > sh scripts/prepare_data.sh CelebA
20 |
21 | 3. For 3D Faces Dataset: we cannot publicly distribute it due to the [license](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads).
22 |
23 | # Training
24 | To train a model on dSprites with ICP:
25 | > chmod +x ./scripts/Train_dsprites.sh
26 |
27 | > sh ./scripts/Train_dsprites.sh
28 |
29 | # Testing
30 | To test a trained model on dSprites:
31 | > chmod +x ./scripts/Test_dsprites.sh
32 |
33 | > sh ./scripts/Test_dsprites.sh
34 |
35 | # MIG Score
36 | To evaluate the MIG Score on the dSprites and 3D Faces datasets:
37 | > chmod +x ./scripts/MIG_dsprites.sh ./scripts/MIG_faces.sh
38 |
39 | > sh ./scripts/MIG_dsprites.sh
40 |
41 | > sh ./scripts/MIG_faces.sh
42 |
43 | # Results
44 | MIG scores of ICP and the compared methods on dSprites and 3D Faces:
45 |
46 | |            | dSprites | 3D Faces |
47 | | :--------: | :------: | :------: |
48 | | Beta-VAE   | 0.22     | 0.54     |
49 | | Beta-TCVAE | 0.38     | 0.62     |
50 | | ICP-ALL    | 0.33     | 0.26     |
51 | | ICP-COM    | 0.20     | 0.57     |
52 | | **ICP**    | **0.48** | **0.73** |
53 |
54 | ICP-ALL denotes ICP with all of the information constraints removed; ICP-COM denotes ICP with only the competing constraints removed.
55 |
56 | # Trained Models
57 | The trained models used to obtain the results in our paper can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1JLQrOvVWbWIXzu_A2l4Ccw) (Password: vd3i) or [Google Drive](https://drive.google.com/drive/folders/19mBHxAVYALPzIQLvvL0uU9-XMLEttBc6?usp=sharing).
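
As a rough illustration of what the MIG numbers measure, the sketch below mirrors the logic of `compute_metric_shapes` in `MIG_Score/metric_helpers/mi_metric.py`, written in plain Python instead of PyTorch. The function name `mig_score` and the toy entropy values are made up for this example and are not part of the repository; the real script estimates the entropies from the trained encoder.

```python
import math


def mig_score(marginal_entropies, cond_entropies, factor_sizes):
    """Mutual Information Gap, averaged over ground-truth factors.

    marginal_entropies: H(z_j) for each latent dimension j.
    cond_entropies:     one row per factor k holding H(z_j | v_k).
    factor_sizes:       number of discrete values of each factor k
                        (e.g. [6, 40, 32, 32] for dSprites in mi_metric.py).
    """
    gaps = []
    for cond_row, size in zip(cond_entropies, factor_sizes):
        # I(z_j; v_k) = H(z_j) - H(z_j | v_k), clamped at zero, sorted descending.
        mutual_infos = sorted(
            (max(h - c, 0.0) for h, c in zip(marginal_entropies, cond_row)),
            reverse=True,
        )
        # Gap between the two most informative latents, normalised by log(size),
        # the entropy of a uniformly distributed factor.
        gaps.append((mutual_infos[0] - mutual_infos[1]) / math.log(size))
    return sum(gaps) / len(gaps)


# Toy example (hypothetical numbers): 3 latent dimensions, 2 factors.
toy_marginal = [2.0, 2.0, 2.0]
toy_conditional = [[1.0, 2.0, 2.0],
                   [2.0, 0.5, 1.9]]
print("MIG: {:.2f}".format(mig_score(toy_marginal, toy_conditional, [3, 6])))
```

A score close to 1 means each factor is captured by a single latent dimension; a score close to 0 means the information about a factor is spread over several dimensions.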
58 | -------------------------------------------------------------------------------- /Disentanglement/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision.datasets import ImageFolder 6 | from torchvision import transforms 7 | 8 | class CustomImageFolder(ImageFolder): 9 | def __init__(self, root, transform=None): 10 | super(CustomImageFolder, self).__init__(root, transform) 11 | 12 | def __getitem__(self, index): 13 | path = self.imgs[index][0] 14 | img = self.loader(path) 15 | if self.transform is not None: 16 | img = self.transform(img) 17 | 18 | return img 19 | 20 | 21 | class CustomTensorDataset(Dataset): 22 | def __init__(self, data_tensor): 23 | self.data_tensor = data_tensor 24 | 25 | def __getitem__(self, index): 26 | return self.data_tensor[index] 27 | 28 | def __len__(self): 29 | return self.data_tensor.size(0) 30 | 31 | 32 | def return_data(args): 33 | name = args.dataset 34 | dset_dir = args.dset_dir 35 | batch_size = args.batch_size 36 | num_workers = args.num_workers 37 | image_size = args.image_size 38 | 39 | if name.lower() == 'celeba': 40 | root = os.path.join(dset_dir, 'CelebA') 41 | transform = transforms.Compose([ 42 | transforms.Resize((image_size, image_size)), 43 | transforms.ToTensor(),]) 44 | train_kwargs = {'root':root, 'transform':transform} 45 | dset = CustomImageFolder 46 | 47 | elif name.lower() == 'dsprites': 48 | root = os.path.join(dset_dir, 'dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz') 49 | data = np.load(root, encoding='bytes') 50 | data = torch.from_numpy(data['imgs']).unsqueeze(1).float() 51 | train_kwargs = {'data_tensor':data} 52 | dset = CustomTensorDataset 53 | 54 | elif name.lower() == 'faces': 55 | root = os.path.join(dset_dir, 'faces/basel_face_renders.pth') 56 | data = torch.load(root).float().div(255).view(-1, 1, 64, 64) 57 | 
train_kwargs = {'data_tensor':data} 58 | dset = CustomTensorDataset 59 | 60 | else: 61 | raise NotImplementedError 62 | 63 | 64 | train_data = dset(**train_kwargs) 65 | train_loader = DataLoader(train_data, 66 | batch_size=batch_size, 67 | shuffle=True, 68 | num_workers=num_workers, 69 | pin_memory=True, 70 | drop_last=True) 71 | 72 | data_loader = train_loader 73 | 74 | return data_loader 75 | -------------------------------------------------------------------------------- /Disentanglement/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from solver import Solver 7 | from utils import str2bool 8 | 9 | torch.backends.cudnn.enabled = True 10 | torch.backends.cudnn.benchmark = True 11 | 12 | 13 | def main(args): 14 | seed = args.seed 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | np.random.seed(seed) 18 | 19 | net = Solver(args) 20 | 21 | if args.train: 22 | net.train() 23 | else: 24 | net.test() 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser(description='Information Competing Process (ICP)') 29 | 30 | parser.add_argument('--train', default=True, type=str2bool, help='train or test') 31 | parser.add_argument('--seed', default=2, type=int, help='random seed') 32 | parser.add_argument('--cuda', default=True, type=str2bool, help='enable cuda') 33 | parser.add_argument('--max_iter', default=1e6, type=float, help='maximum training iteration') 34 | parser.add_argument('--batch_size', default=64, type=int, help='batch size') 35 | 36 | parser.add_argument('--z_dim', default=10, type=int, help='dimension of the representation z') 37 | parser.add_argument('--y_dim', default=2, type=int, help='dimension of the representation y') 38 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 39 | 40 | parser.add_argument('--gamma', default=1, type=float, help='Compete - MLP') 41 | 
parser.add_argument('--alpha', default=1, type=float, help='Max - DIS') 42 | parser.add_argument('--beta', default=4, type=float, help='MIN - KL') 43 | parser.add_argument('--rec', default=1, type=float, help='Synergy - REC') 44 | 45 | parser.add_argument('--dset_dir', default='data', type=str, help='dataset directory') 46 | parser.add_argument('--dataset', default='CelebA', type=str, help='dataset name: [CelebA, faces, dsprites]') 47 | parser.add_argument('--image_size', default=64, type=int, help='image size. [64, 128].') 48 | parser.add_argument('--num_workers', default=4, type=int, help='dataloader num_workers') 49 | 50 | parser.add_argument('--save_name', default='main', type=str, help='output name.') 51 | parser.add_argument('--save_output', default=True, type=str2bool, help='save traverse images and gif.') 52 | 53 | parser.add_argument('--display_step', default=10000, type=int, help='number of iterations after which loss data is printed.') 54 | parser.add_argument('--save_step', default=10000, type=int, help='number of iterations after which a checkpoint is saved.') 55 | 56 | parser.add_argument('--ckpt_name', default='last', type=str, help='load previous checkpoint. 
insert checkpoint filename.')
57 |     parser.add_argument('--global_iter', default=0, type=lambda s: int(float(s)), help='iteration to resume training from; accepts values such as 2.5e5 and stores them as int so checkpoint names stay integral.')
58 |
59 |     args = parser.parse_args()
60 |
61 |     main(args)
62 |
--------------------------------------------------------------------------------
/Disentanglement/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 |
5 | def reparametrize(mu, logvar):
6 |     std = logvar.div(2).exp()
7 |     eps = std.data.new(std.size()).normal_()
8 |     return mu + std*eps
9 |
10 |
11 | class View(nn.Module):
12 |     def __init__(self, size):
13 |         super(View, self).__init__()
14 |         self.size = size
15 |
16 |     def forward(self, tensor):
17 |         return tensor.view(self.size)
18 |
19 |
20 | class Discriminator(nn.Module):
21 |     def __init__(self, y_dim):
22 |         super(Discriminator, self).__init__()
23 |         self.net = nn.Sequential(
24 |             nn.Linear(y_dim * 2, y_dim),
25 |             nn.LeakyReLU(0.2, True),
26 |             nn.Linear(y_dim, y_dim),
27 |             nn.LeakyReLU(0.2, True),
28 |             nn.Linear(y_dim, 1),
29 |             nn.Sigmoid()
30 |         )
31 |
32 |     def forward(self, y):
33 |         return self.net(y).squeeze()
34 |
35 |
36 | class MLP(nn.Module):
37 |     def __init__(self, s_dim, t_dim):
38 |         super(MLP, self).__init__()
39 |         self.net = nn.Sequential(
40 |             nn.Linear(s_dim, s_dim),
41 |             nn.LeakyReLU(0.2, True),
42 |             nn.Linear(s_dim, t_dim),
43 |             nn.LeakyReLU(0.2, True),
44 |             nn.Linear(t_dim, t_dim),
45 |             nn.ReLU()
46 |         )
47 |
48 |     def forward(self, s):
49 |         t = self.net(s)
50 |         return t
51 |
52 | # -------------- For dSprites, 3D faces datasets --------------
53 | class ICP_Encoder(nn.Module):
54 |     def __init__(self, z_dim, y_dim, nc):
55 |         super(ICP_Encoder, self).__init__()
56 |
57 |         self.nc = nc
58 |         self.z_dim = z_dim
59 |         self.y_dim = y_dim
60 |
61 |         sequence = [nn.Conv2d(nc, 32, 4, 2, 1),
62 |                     nn.ReLU(True),
63 |                     nn.Conv2d(32, 32, 4, 2, 1),
64 |                     nn.ReLU(True),
65 |                     nn.Conv2d(32, 64, 4, 2,
1), 66 | nn.ReLU(True), 67 | nn.Conv2d(64, 64, 4, 2, 1), 68 | nn.ReLU(True), 69 | nn.Conv2d(64, 256, 4, 1), 70 | nn.ReLU(True), 71 | View((-1, 256*1*1))] 72 | 73 | sequence_z = sequence + [nn.Linear(256, z_dim * 2)] 74 | sequence_y = sequence + [nn.Linear(256, y_dim)] 75 | 76 | self.encoder_z = nn.Sequential(*sequence_z) 77 | self.encoder_y = nn.Sequential(*sequence_y) 78 | 79 | def forward(self, x): 80 | distributions = self.encoder_z(x) 81 | mu = distributions[:, : self.z_dim] 82 | logvar = distributions[:, self.z_dim : ] 83 | z = reparametrize(mu, logvar) 84 | y = self.encoder_y(x) 85 | 86 | return z, mu, logvar, y 87 | 88 | class ICP_Decoder(nn.Module): 89 | def __init__(self, dim, nc): 90 | super(ICP_Decoder, self).__init__() 91 | 92 | self.nc = nc 93 | self.dim = dim 94 | 95 | self.decoder = nn.Sequential( 96 | nn.Linear(dim, 256), 97 | View((-1, 256, 1, 1)), 98 | nn.ReLU(True), 99 | nn.ConvTranspose2d(256, 64, 4), 100 | nn.ReLU(True), 101 | nn.ConvTranspose2d(64, 64, 4, 2, 1), 102 | nn.ReLU(True), 103 | nn.ConvTranspose2d(64, 32, 4, 2, 1), 104 | nn.ReLU(True), 105 | nn.ConvTranspose2d(32, 32, 4, 2, 1), 106 | nn.ReLU(True), 107 | nn.ConvTranspose2d(32, nc, 4, 2, 1) 108 | ) 109 | 110 | def forward(self, zy): 111 | x_rec = self.decoder(zy) 112 | 113 | return x_rec 114 | 115 | # -------------- For CelebA dataset -------------- 116 | class ICP_Encoder_Big(nn.Module): 117 | def __init__(self, z_dim, y_dim, nc, img_dim = 128): 118 | super(ICP_Encoder_Big, self).__init__() 119 | 120 | self.img_dim = img_dim 121 | self.nc = nc 122 | self.z_dim = z_dim 123 | self.y_dim = y_dim 124 | self.layer_num = int(np.log2(img_dim)) - 3 125 | self.max_channel_num = img_dim * 8 126 | self.f_num = 4 127 | 128 | # Encoder 129 | sequence = [] 130 | for i in range(self.layer_num + 1): 131 | now_channel_num = self.max_channel_num // 2**(self.layer_num - i) 132 | 133 | if i == 0: 134 | sequence += [nn.Conv2d(in_channels = self.nc, out_channels = now_channel_num, 135 | kernel_size = 4, 
stride = 2, padding = 1),] 136 | else: 137 | sequence += [nn.Conv2d(in_channels = pre_channel_num, out_channels = now_channel_num, 138 | kernel_size = 4, stride = 2, padding = 1), 139 | nn.BatchNorm2d(now_channel_num),] 140 | sequence += [nn.LeakyReLU(0.2, True)] 141 | 142 | pre_channel_num = now_channel_num 143 | 144 | sequence_z = sequence + [View((-1, self.f_num * self.f_num * self.max_channel_num)), 145 | nn.Linear(self.f_num * self.f_num * self.max_channel_num, self.z_dim * 2)] 146 | 147 | sequence_y = sequence + [View((-1, self.f_num * self.f_num * self.max_channel_num)), 148 | nn.Linear(self.f_num * self.f_num * self.max_channel_num, self.y_dim)] 149 | 150 | self.encoder_z = nn.Sequential(*sequence_z) 151 | self.encoder_y = nn.Sequential(*sequence_y) 152 | 153 | def forward(self, x): 154 | # Encode 155 | distributions = self.encoder_z(x) 156 | mu = distributions[:, : self.z_dim] 157 | logvar = distributions[:, self.z_dim : ] 158 | z = reparametrize(mu, logvar) 159 | y = self.encoder_y(x) 160 | 161 | return z, mu, logvar, y 162 | 163 | class ICP_Decoder_Big(nn.Module): 164 | def __init__(self, dim, nc, img_dim = 128): 165 | super(ICP_Decoder_Big, self).__init__() 166 | 167 | self.img_dim = img_dim 168 | self.nc = nc 169 | self.dim = dim 170 | self.layer_num = int(np.log2(img_dim)) - 3 171 | self.max_channel_num = img_dim * 8 172 | self.f_num = 4 173 | 174 | # Decoder 175 | sequence = [nn.Linear(self.dim, self.f_num * self.f_num * self.max_channel_num), 176 | nn.ReLU(True), 177 | View((-1, self.max_channel_num, self.f_num, self.f_num)),] 178 | pre_channel_num = self.max_channel_num 179 | for i in range(self.layer_num): 180 | now_channel_num = self.max_channel_num // 2**(i + 1) 181 | 182 | sequence += [nn.ConvTranspose2d(in_channels = pre_channel_num, out_channels = now_channel_num, 183 | kernel_size = 4, stride = 2, padding = 1), 184 | nn.BatchNorm2d(now_channel_num), 185 | nn.ReLU(True),] 186 | 187 | pre_channel_num = now_channel_num 188 | 189 | sequence += 
[nn.ConvTranspose2d(in_channels = now_channel_num, out_channels = self.nc, 190 | kernel_size = 4, stride = 2, padding = 1), 191 | # nn.Tanh() 192 | ] 193 | self.decoder = nn.Sequential(*sequence) 194 | 195 | def forward(self, z): 196 | x_rec = self.decoder(z) 197 | 198 | return x_rec 199 | -------------------------------------------------------------------------------- /Disentanglement/scripts/MIG_dsprites.sh: -------------------------------------------------------------------------------- 1 | # 2 | python ./MIG_Score/disentanglement_metrics.py \ 3 | --checkpt ./checkpoints/dsprites_ICP/netE_last.pth \ 4 | --name shapes 5 | -------------------------------------------------------------------------------- /Disentanglement/scripts/MIG_faces.sh: -------------------------------------------------------------------------------- 1 | # 2 | python ./MIG_Score/disentanglement_metrics.py \ 3 | --checkpt ./checkpoints/faces_ICP/netE_last.pth \ 4 | --name faces 5 | -------------------------------------------------------------------------------- /Disentanglement/scripts/Test_celeba.sh: -------------------------------------------------------------------------------- 1 | # 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --train False \ 4 | --dataset celebA \ 5 | --z_dim 32 --y_dim 32 \ 6 | --image_size 128 \ 7 | --save_name celebA_ICP 8 | -------------------------------------------------------------------------------- /Disentanglement/scripts/Test_dsprites.sh: -------------------------------------------------------------------------------- 1 | # 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --train False \ 4 | --dataset dsprites \ 5 | --z_dim 10 --y_dim 2 \ 6 | --save_name dsprites_ICP 7 | -------------------------------------------------------------------------------- /Disentanglement/scripts/Test_faces.sh: -------------------------------------------------------------------------------- 1 | # 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --train False \ 4 | --dataset faces \ 5 
| --z_dim 10 --y_dim 2 \ 6 | --save_name faces_ICP 7 | -------------------------------------------------------------------------------- /Disentanglement/scripts/Train_celeba.sh: -------------------------------------------------------------------------------- 1 | # 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --dataset celebA \ 4 | --lr 5e-4 \ 5 | --batch_size 64 \ 6 | --z_dim 32 --y_dim 32 \ 7 | --max_iter 1.5e6 \ 8 | --image_size 128 \ 9 | --gamma 1 --alpha 5 --beta 5 --rec 1 \ 10 | --save_name celebA_ICP -------------------------------------------------------------------------------- /Disentanglement/scripts/Train_dsprites.sh: -------------------------------------------------------------------------------- 1 | # 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --dataset dsprites \ 4 | --lr 5e-4 \ 5 | --batch_size 64 \ 6 | --z_dim 10 --y_dim 2 \ 7 | --max_iter 2.5e5 \ 8 | --gamma 1 --alpha 1 --beta 10 --rec 0.5 \ 9 | --save_name dsprites_ICP 10 | 11 | CUDA_VISIBLE_DEVICES=0 python main.py \ 12 | --dataset dsprites \ 13 | --lr 5e-4 \ 14 | --batch_size 64 \ 15 | --z_dim 10 --y_dim 2 \ 16 | --max_iter 3.5e5 \ 17 | --gamma 1 --alpha 1 --beta 10 --rec 1 \ 18 | --global_iter 2.5e5 \ 19 | --ckpt_name 250000 \ 20 | --save_name dsprites_ICP -------------------------------------------------------------------------------- /Disentanglement/scripts/Train_faces.sh: -------------------------------------------------------------------------------- 1 | # 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --dataset faces \ 4 | --lr 5e-4 \ 5 | --batch_size 64 \ 6 | --z_dim 10 --y_dim 2 \ 7 | --max_iter 1e5 \ 8 | --gamma 1 --alpha 1 --beta 4 --rec 1 \ 9 | --save_name faces_ICP -------------------------------------------------------------------------------- /Disentanglement/scripts/prepare_data.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | 4 | if [ "$1" = "dsprites" ]; then 5 | git clone 
https://github.com/deepmind/dsprites-dataset.git 6 | cd dsprites-dataset 7 | rm -rf .git* *.md LICENSE *.ipynb *.gif *.hdf5 8 | 9 | elif [ "$1" = "CelebA" ]; then 10 | unzip img_align_celeba.zip 11 | mkdir CelebA 12 | mv img_align_celeba CelebA 13 | fi 14 | -------------------------------------------------------------------------------- /Disentanglement/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torchvision.utils import save_image 6 | 7 | from utils import cuda, grid2gif 8 | from model import ICP_Encoder, ICP_Decoder, Discriminator, MLP, ICP_Encoder_Big, ICP_Decoder_Big 9 | from dataset import return_data 10 | import numpy as np 11 | import random 12 | import torchvision.utils as vutils 13 | 14 | def reconstruction_loss(x, x_recon, distribution): 15 | batch_size = x.size(0) 16 | assert batch_size != 0 17 | 18 | if distribution == 'bernoulli': 19 | recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum').div(batch_size) 20 | elif distribution == 'gaussian': 21 | recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size) 22 | else: 23 | recon_loss = None 24 | 25 | return recon_loss 26 | 27 | 28 | def kl_divergence(mu, logvar): 29 | batch_size = mu.size(0) 30 | assert batch_size != 0 31 | if mu.data.ndimension() == 4: 32 | mu = mu.view(mu.size(0), mu.size(1)) 33 | if logvar.data.ndimension() == 4: 34 | logvar = logvar.view(logvar.size(0), logvar.size(1)) 35 | 36 | klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp()) 37 | total_kld = klds.sum(1).mean(0, True) 38 | dimension_wise_kld = klds.mean(0) 39 | mean_kld = klds.mean(1).mean(0, True) 40 | 41 | return total_kld, dimension_wise_kld, mean_kld 42 | 43 | class Solver(object): 44 | def __init__(self, args): 45 | self.use_cuda = args.cuda and torch.cuda.is_available() 46 | self.max_iter = args.max_iter 47 | self.global_iter = 
args.global_iter 48 | 49 | self.z_dim = args.z_dim 50 | self.y_dim = args.y_dim 51 | 52 | self.gamma = args.gamma # MLP 53 | self.alpha = args.alpha # DIS 54 | self.beta = args.beta # KL 55 | self.rec = args.rec # REC 56 | 57 | self.image_size = args.image_size 58 | self.lr = args.lr 59 | 60 | if args.dataset.lower() == 'dsprites' or args.dataset.lower() == 'faces': 61 | self.nc = 1 62 | self.decoder_dist = 'bernoulli' 63 | 64 | self.netE = cuda(ICP_Encoder(z_dim = self.z_dim, y_dim = self.y_dim, nc = self.nc), self.use_cuda) 65 | self.netG_Less = cuda(ICP_Decoder(dim = self.z_dim, nc = self.nc), self.use_cuda) 66 | self.netG_More = cuda(ICP_Decoder(dim = self.y_dim, nc = self.nc), self.use_cuda) 67 | self.netG_Total = cuda(ICP_Decoder(dim = self.z_dim + self.y_dim, nc = self.nc), self.use_cuda) 68 | self.netD = cuda(Discriminator(y_dim = self.y_dim), self.use_cuda) 69 | self.netZ2Y = cuda(MLP(s_dim = self.z_dim, t_dim = self.y_dim), self.use_cuda) 70 | self.netY2Z = cuda(MLP(s_dim = self.y_dim, t_dim = self.z_dim), self.use_cuda) 71 | 72 | elif args.dataset.lower() == 'celeba': 73 | self.nc = 3 74 | self.decoder_dist = 'gaussian' 75 | 76 | self.netE = cuda(ICP_Encoder_Big(z_dim = self.z_dim, y_dim = self.y_dim, nc = self.nc), self.use_cuda) 77 | self.netG_Less = cuda(ICP_Decoder_Big(dim = self.z_dim, nc = self.nc), self.use_cuda) 78 | self.netG_More = cuda(ICP_Decoder_Big(dim = self.y_dim, nc = self.nc), self.use_cuda) 79 | self.netG_Total = cuda(ICP_Decoder_Big(dim = self.z_dim + self.y_dim, nc = self.nc), self.use_cuda) 80 | self.netD = cuda(Discriminator(y_dim = self.y_dim), self.use_cuda) 81 | self.netZ2Y = cuda(MLP(s_dim = self.z_dim, t_dim = self.y_dim), self.use_cuda) 82 | self.netY2Z = cuda(MLP(s_dim = self.y_dim, t_dim = self.z_dim), self.use_cuda) 83 | else: 84 | raise NotImplementedError 85 | 86 | self.optimizerG = optim.Adam([{'params' : self.netE.parameters()}, 87 | {'params' : self.netG_Less.parameters()}, 88 | {'params' : 
self.netG_More.parameters()}, 89 | {'params' : self.netG_Total.parameters()}], lr = self.lr, betas = (0.9, 0.999)) 90 | self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr, betas = (0.9, 0.999)) 91 | self.optimizerMLP = optim.Adam([{'params' : self.netZ2Y.parameters()}, 92 | {'params' : self.netY2Z.parameters()}], lr = self.lr, betas = (0.9, 0.999)) 93 | 94 | 95 | self.save_name = args.save_name 96 | self.ckpt_dir = os.path.join('./checkpoints/', args.save_name) 97 | if not os.path.exists(self.ckpt_dir): 98 | os.makedirs(self.ckpt_dir, exist_ok=True) 99 | self.ckpt_name = args.ckpt_name 100 | if args.train and self.ckpt_name is not None: 101 | self.load_checkpoint(self.ckpt_name) 102 | self.save_output = args.save_output 103 | self.output_dir = os.path.join('./trainvis/', args.save_name) 104 | if not os.path.exists(self.output_dir): 105 | os.makedirs(self.output_dir, exist_ok=True) 106 | 107 | self.display_step = args.display_step 108 | self.save_step = args.save_step 109 | 110 | self.dataset = args.dataset 111 | self.batch_size = args.batch_size 112 | self.data_loader = return_data(args) 113 | 114 | def train(self): 115 | self.net_mode(train=True) 116 | out = False 117 | ones = cuda(torch.ones(self.batch_size, dtype = torch.float), self.use_cuda) 118 | zeros = cuda(torch.zeros(self.batch_size, dtype = torch.float), self.use_cuda) 119 | 120 | while not out: 121 | for x in self.data_loader: 122 | self.global_iter += 1 123 | 124 | ### Train ### 125 | x = cuda(x, self.use_cuda) 126 | 127 | ## Update MLP 128 | z, mu, logvar, y = self.netE(x) 129 | rec_y = self.netZ2Y(z) 130 | rec_z = self.netY2Z(y) 131 | 132 | # loss MLP 133 | loss_MLP = self.gamma * (F.mse_loss(rec_z, z.detach(), reduction='sum').div(self.batch_size) \ 134 | + F.mse_loss(rec_y, y.detach(), reduction='sum').div(self.batch_size)) 135 | self.optimizerMLP.zero_grad() 136 | loss_MLP.backward() 137 | self.optimizerMLP.step() 138 | 139 | # loss D 140 | index = np.arange(x.size()[0]) 141 | 
np.random.shuffle(index) 142 | y_shuffle = y.clone() 143 | y_shuffle = y_shuffle[index, :] 144 | 145 | real_score = self.netD(torch.cat([y.detach(), y.detach()], dim=1)) 146 | fake_score = self.netD(torch.cat([y.detach(), y_shuffle.detach()], dim=1)) 147 | 148 | loss_D = self.alpha * (F.binary_cross_entropy(real_score, ones, reduction='sum').div(self.batch_size) \ 149 | + F.binary_cross_entropy(fake_score, zeros, reduction='sum').div(self.batch_size)) 150 | 151 | self.optimizerD.zero_grad() 152 | loss_D.backward() 153 | self.optimizerD.step() 154 | 155 | ## Update G 156 | z, mu, logvar, y = self.netE(x) 157 | 158 | rec_y = self.netZ2Y(z) 159 | rec_z = self.netY2Z(y) 160 | 161 | rec_less_x = self.netG_Less(z) 162 | rec_more_x = self.netG_More(y) 163 | rec_x = self.netG_Total(torch.cat([z, y], dim=1)) 164 | 165 | # loss MLP 166 | loss_MLP = (F.mse_loss(rec_z, z.detach(), reduction='sum').div(self.batch_size) \ 167 | + F.mse_loss(rec_y, y.detach(), reduction='sum').div(self.batch_size)) 168 | 169 | # loss D 170 | index = np.arange(x.size()[0]) 171 | np.random.shuffle(index) 172 | y_shuffle = y.clone() 173 | y_shuffle = y_shuffle[index, :] 174 | 175 | real_score = self.netD(torch.cat([y, y.detach()], dim=1)) 176 | fake_score = self.netD(torch.cat([y, y_shuffle.detach()], dim=1)) 177 | 178 | loss_D = (F.binary_cross_entropy(real_score, ones, reduction='sum').div(self.batch_size) \ 179 | + F.binary_cross_entropy(fake_score, zeros, reduction='sum').div(self.batch_size)) 180 | 181 | # loss KL 182 | loss_kl, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) 183 | 184 | # loss Rec 185 | loss_less_rec = reconstruction_loss(x, rec_less_x, self.decoder_dist) 186 | loss_more_rec = reconstruction_loss(x, rec_more_x, self.decoder_dist) 187 | loss_total_rec = reconstruction_loss(x, rec_x, self.decoder_dist) 188 | loss_rec = loss_total_rec + loss_less_rec + loss_more_rec 189 | 190 | # total Loss 191 | loss_total = self.rec * loss_rec + self.beta * loss_kl + self.alpha * loss_D - 
                             self.gamma * loss_MLP

                self.optimizerG.zero_grad()
                loss_total.backward()
                self.optimizerG.step()


                ### Test ###
                if self.global_iter % 50 == 0:
                    print('[Iter-{}]: loss_total: {:.5f}, loss_MLP: {:.5f}, loss_D: {:.5f}, loss_kl: {:.5f}, loss_less_rec: {:.5f}, loss_more_rec: {:.5f}, loss_total_rec: {:.5f}'.format(
                        self.global_iter, loss_total.item(), loss_MLP.item(), loss_D.item(), loss_kl.item(),
                        loss_less_rec.item(), loss_more_rec.item(), loss_total_rec.item()))

                if self.global_iter % self.display_step == 0:
                    if self.save_output:
                        self.viz_traverse()

                if self.global_iter % self.save_step == 0:
                    self.save_checkpoint('last')
                    print('Saved checkpoint(iter:{})'.format(self.global_iter))

                if self.global_iter % 50000 == 0:
                    self.save_checkpoint(str(self.global_iter))

                if self.global_iter >= self.max_iter:
                    out = True
                    break

    def test(self, epoch = 'last'):
        if not os.path.exists("./testvis/"):
            os.mkdir("./testvis/")

        if not os.path.exists("./testvis/{}/".format(self.save_name)):
            os.mkdir("./testvis/{}/".format(self.save_name))

        self.load_checkpoint(epoch)
        self.net_mode(train = False)

        limit = 3
        inter = 0.2
        interpolation = torch.arange(-limit, limit + inter, inter)
        print("testing: ", interpolation)
        # Random
        n_dsets = len(self.data_loader.dataset)
        fixed_idxs = [random.randint(1, n_dsets-1) for _ in range(12)]

        for i, fixed_idx in enumerate(fixed_idxs):
            fixed_img = cuda(self.data_loader.dataset.__getitem__(fixed_idx), self.use_cuda).unsqueeze(0)
            vutils.save_image(fixed_img.cpu().data[:1, ], './testvis/{}/{}.jpg'.format(self.save_name, fixed_idx))
            _, z, _, y = self.netE(fixed_img)

            x_less = self.netG_Less(z)
            x_more = self.netG_More(y)
            x_rec = self.netG_Total(torch.cat([z, y], dim=1))
            if self.dataset == 'faces':
                x_less = torch.sigmoid(x_less)
                x_more = torch.sigmoid(x_more)
                x_rec = torch.sigmoid(x_rec)
            vutils.save_image(x_less.cpu().data[:1, ], './testvis/{}/{}_less.png'.format(self.save_name, fixed_idx))
            vutils.save_image(x_more.cpu().data[:1, ], './testvis/{}/{}_more.png'.format(self.save_name, fixed_idx))
            vutils.save_image(x_rec.cpu().data[:1, ], './testvis/{}/{}_rec.png'.format(self.save_name, fixed_idx))

            for row in range(self.z_dim):
                z_tmp = z.clone()
                for j, val in enumerate(interpolation):
                    z_tmp[:, row] = val
                    sample = self.netG_Total(torch.cat([z_tmp, y], dim=1))
                    if self.dataset == 'faces':
                        sample = torch.sigmoid(sample)
                    vutils.save_image(sample.cpu().data[:1, ], './testvis/{}/{}_{}_{}.jpg'.format(self.save_name, fixed_idx, row, j))

            fixed_img = cuda(self.data_loader.dataset.__getitem__(fixed_idx), self.use_cuda).unsqueeze(0)
            vutils.save_image(fixed_img.cpu().data[:1, ], './testvis/{}/{}.jpg'.format(self.save_name, fixed_idx))
            _, z, _, y = self.netE(fixed_img)


            for row in range(self.y_dim):
                y_tmp = y.clone()
                for j, val in enumerate(interpolation):
                    y_tmp[:, row] = val
                    sample = self.netG_Total(torch.cat([z, y_tmp], dim=1))
                    if self.dataset == 'faces':
                        sample = torch.sigmoid(sample)
                    vutils.save_image(sample.cpu().data[:1, ], './testvis/{}/{}_{}_{}.jpg'.format(self.save_name, fixed_idx, row + self.z_dim, j))
        print("done!")

    def viz_traverse(self, limit = 3, inter = 2/3, loc = -1):
        self.net_mode(train = False)

        decoder = self.netG_Total
        encoder = self.netE
        interpolation = torch.arange(-limit, limit+0.1, inter)

        n_dsets = len(self.data_loader.dataset)
        rand_idx = random.randint(1, n_dsets-1)

        random_img = self.data_loader.dataset.__getitem__(rand_idx)
        random_img = cuda(random_img, self.use_cuda).unsqueeze(0)

        _, random_img_z, _, random_img_y = encoder(random_img)

        random_z = cuda(torch.rand(1, self.z_dim), self.use_cuda)

        if self.dataset == 'dsprites':
            fixed_idx1 = 87040 # square
            fixed_idx2 = 332800 # ellipse
            fixed_idx3 = 578560 # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
            fixed_img1 = cuda(fixed_img1, self.use_cuda).unsqueeze(0)
            _, fixed_img_z1, _, fixed_img_y1 = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
            fixed_img2 = cuda(fixed_img2, self.use_cuda).unsqueeze(0)
            _, fixed_img_z2, _, fixed_img_y2 = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
            fixed_img3 = cuda(fixed_img3, self.use_cuda).unsqueeze(0)
            _, fixed_img_z3, _, fixed_img_y3 = encoder(fixed_img3)

            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}

            Y = {'fixed_square':fixed_img_y1, 'fixed_ellipse':fixed_img_y2,
                 'fixed_heart':fixed_img_y3, 'random_img':random_img_y}
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
            fixed_img = cuda(fixed_img, self.use_cuda).unsqueeze(0)
            _, fixed_img_z, _, fixed_img_y = encoder(fixed_img)

            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z}
            Y = {'fixed_img':fixed_img_y, 'random_img':random_img_y}

        gifs = []
        for key in Z.keys():
            z_ori = Z[key]
            y = Y[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    # .data detaches the decoded sample from the autograd graph
                    sample = decoder(torch.cat([z, y], dim=1)).data
                    if self.dataset == 'faces':
                        sample = torch.sigmoid(sample)
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)

        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(output_dir, exist_ok=True)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, self.image_size, self.image_size).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)

                grid2gif(os.path.join(output_dir, key+'*.jpg'),
                         os.path.join(output_dir, key+'.gif'), delay=10)

        self.net_mode(train = True)

    def net_mode(self, train):
        if not isinstance(train, bool):
            # raise needs an exception instance; a bare string is not raisable in Python 3
            raise TypeError('Only bool type is supported. True or False')

        if train:
            self.netE.train()
            self.netG_Less.train()
            self.netG_More.train()
            self.netG_Total.train()
            self.netD.train()
            self.netZ2Y.train()
            self.netY2Z.train()
        else:
            self.netE.eval()
            self.netG_Less.eval()
            self.netG_More.eval()
            self.netG_Total.eval()
            self.netD.eval()
            self.netZ2Y.eval()
            self.netY2Z.eval()

    def save_checkpoint(self, epoch):
        netE_path = "checkpoints/{}/netE_{}.pth".format(self.save_name, epoch)
        netG_Less_path = "checkpoints/{}/netG_Less_{}.pth".format(self.save_name, epoch)
        netG_More_path = "checkpoints/{}/netG_More_{}.pth".format(self.save_name, epoch)
        netG_Total_path = "checkpoints/{}/netG_Total_{}.pth".format(self.save_name, epoch)
        netD_path = "checkpoints/{}/netD_{}.pth".format(self.save_name, epoch)
        netZ2Y_path = "checkpoints/{}/netZ2Y_{}.pth".format(self.save_name, epoch)
        netY2Z_path = "checkpoints/{}/netY2Z_{}.pth".format(self.save_name, epoch)

        if not os.path.exists("checkpoints"):
            os.mkdir("checkpoints")

        if not os.path.exists("checkpoints/{}".format(self.save_name)):
            os.mkdir("checkpoints/{}".format(self.save_name))

        torch.save(self.netE.state_dict(), netE_path)
        torch.save(self.netG_Less.state_dict(), netG_Less_path)
        torch.save(self.netG_More.state_dict(), netG_More_path)
        torch.save(self.netG_Total.state_dict(), netG_Total_path)
        torch.save(self.netD.state_dict(), netD_path)
        torch.save(self.netZ2Y.state_dict(), netZ2Y_path)
        torch.save(self.netY2Z.state_dict(), netY2Z_path)

    def load_checkpoint(self, epoch):
        netE_path = "checkpoints/{}/netE_{}.pth".format(self.save_name, epoch)
        netG_Less_path = "checkpoints/{}/netG_Less_{}.pth".format(self.save_name, epoch)
        netG_More_path = "checkpoints/{}/netG_More_{}.pth".format(self.save_name, epoch)
        netG_Total_path = "checkpoints/{}/netG_Total_{}.pth".format(self.save_name, epoch)
        netD_path = "checkpoints/{}/netD_{}.pth".format(self.save_name, epoch)
        netZ2Y_path = "checkpoints/{}/netZ2Y_{}.pth".format(self.save_name, epoch)
        netY2Z_path = "checkpoints/{}/netY2Z_{}.pth".format(self.save_name, epoch)

        # Only the encoder file is checked; the other files are assumed to sit next to it
        if os.path.isfile(netE_path):
            self.netE.load_state_dict(torch.load(netE_path))
            self.netG_Less.load_state_dict(torch.load(netG_Less_path))
            self.netG_More.load_state_dict(torch.load(netG_More_path))
            self.netG_Total.load_state_dict(torch.load(netG_Total_path))
            self.netD.load_state_dict(torch.load(netD_path))
            self.netZ2Y.load_state_dict(torch.load(netZ2Y_path))
            self.netY2Z.load_state_dict(torch.load(netY2Z_path))
            print("=> loaded checkpoint '{} (iter {})'".format(netE_path, self.global_iter))
        else:
            print("=> no checkpoint found at '{}'".format(netE_path))

--------------------------------------------------------------------------------
/Disentanglement/utils.py:
--------------------------------------------------------------------------------
import argparse
import subprocess

import torch

def cuda(tensor, uses_cuda):
    # Move the tensor to the GPU only when CUDA use is requested
    return tensor.cuda() if uses_cuda else tensor


def str2bool(v):
    # argparse helper: parse common textual spellings of a boolean flag
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def grid2gif(image_str, output_gif, delay=100):
    # Shells out to ImageMagick's `convert`; shell=True lets the shell expand the glob in image_str
    str1 = 'convert -delay '+str(delay)+' -loop 0 ' + image_str + ' ' + output_gif
    subprocess.call(str1, shell=True)

--------------------------------------------------------------------------------
/Paper/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Paper/poster.pdf
--------------------------------------------------------------------------------
/Paper/信息竞争式的多样化特征学习.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/InformationCompetingProcess/d5ce61f8af647ee15a8ce55e7ad71c373e48435e/Paper/信息竞争式的多样化特征学习.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This is the project page of our paper:

"Information Competing Process for Learning Diversified Representations." Hu, J., Ji, R., Zhang, S., Sun, X., Ye, Q., Lin, C. W., & Tian, Q. In *NeurIPS 2019.* \[[Paper](http://papers.nips.cc/paper/8490-information-competing-process-for-learning-diversified-representations)\] \[[Poster](https://github.com/hujiecpp/InformationCompetingProcess/tree/master/Paper/poster.pdf)\] \[[Introduction in Chinese](https://github.com/hujiecpp/InformationCompetingProcess/tree/master/Paper/信息竞争式的多样化特征学习.pdf)\]

If you have any problems, please feel free to contact us (hujie.cpp@gmail.com).

# 1. Supervised Setting: Classification Task
The code, usage, models, and results for the classification task can be found in: [./Classification/](https://github.com/hujiecpp/InformationCompetingProcess/tree/master/Classification)

We implement ICP to train VGG16, GoogLeNet, ResNet20, and DenseNet40 on the CIFAR-10 and CIFAR-100 datasets.

Our code for the classification task is based on [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) and the models from [KSE](https://github.com/yuchaoli/KSE/tree/master/model).

# 2. Self-Supervised Setting: Disentanglement Task
The code, usage, models, and results for the disentanglement task can be found in: [./Disentanglement/](https://github.com/hujiecpp/InformationCompetingProcess/tree/master/Disentanglement)

We implement ICP to train Beta-VAE on the dSprites, 3D Faces, and CelebA datasets.

Our code for the disentanglement task is based on [Beta-VAE](https://github.com/1Konny/Beta-VAE).

The evaluation metric (MIG) for disentanglement is from [beta-tcvae](https://github.com/rtqichen/beta-tcvae), and we thank Ricky for helping us use the 3D Faces dataset.

# 3. Citation
If our paper helps your research, please cite it in your publications:
```
@inproceedings{hu2019information,
  title={Information Competing Process for Learning Diversified Representations},
  author={Hu, Jie and Ji, Rongrong and Zhang, ShengChuan and Sun, Xiaoshuai and Ye, Qixiang and Lin, Chia-Wen and Tian, Qi},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
```

--------------------------------------------------------------------------------
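The `test` and `viz_traverse` methods in /Disentanglement/solver.py both perform a latent traversal: one coordinate of an encoded latent code is swept over evenly spaced values while the others stay fixed, and each modified code is decoded to an image. The following is a minimal sketch of that sweep only, under stated assumptions: plain Python lists stand in for the torch tensors, no decoder is called, and the names `traversal_values` and `traverse_latent` are hypothetical helpers that do not exist in the repository.

```python
def traversal_values(limit=3.0, step=0.2):
    """Evenly spaced values from -limit to +limit, both ends included."""
    # Count the steps as an integer to avoid float drift at the upper bound.
    n = int(round(2 * limit / step)) + 1
    return [-limit + i * step for i in range(n)]


def traverse_latent(z, limit=3.0, step=0.2):
    """Yield (dim, index, code) with one coordinate of z overwritten at a time."""
    values = traversal_values(limit, step)
    for dim in range(len(z)):
        for j, val in enumerate(values):
            code = list(z)   # copy, so the other coordinates keep their encoded values
            code[dim] = val  # sweep only the selected coordinate
            yield dim, j, code


if __name__ == "__main__":
    # Hypothetical 2-dimensional latent code; a real one would come from the encoder.
    for dim, j, code in traverse_latent([0.5, -0.5], limit=1.0, step=1.0):
        print(dim, j, code)
```

In the solver, each yielded code would correspond to one decoded image, saved under a file name built from the image index, the latent dimension, and the step index.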