├── .gitignore ├── README.md ├── datasets └── datasets.py ├── main.py ├── misc ├── ELBO.PNG └── monte_carlo.PNG ├── model.py ├── solver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | summary 3 | checkpoints 4 | datasets/MNIST 5 | 6 | experiments.py 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Variational Information Bottleneck 2 |
3 | 4 | ### Overview 5 | Pytorch implementation of Deep Variational Information Bottleneck([paper], [original code]) 6 | 7 | ![ELBO](misc/ELBO.PNG) 8 | ![monte_carlo](misc/monte_carlo.PNG) 9 |
10 | 11 | ### Dependencies 12 | ``` 13 | python 3.6.4 14 | pytorch 0.3.1.post2 15 | tensorboardX(optional) 16 | tensorflow(optional) 17 | ``` 18 |
19 | 20 | ### Usage 21 | 1. train 22 | ``` 23 | python main.py --mode train --beta 1e-3 --tensorboard True --env_name [NAME] 24 | ``` 25 | 2. test 26 | ``` 27 | python main.py --mode test --env_name [NAME] --load_ckpt best_acc.tar 28 | ``` 29 |
30 | 31 | ### References 32 | 1. Deep Learning and the Information Bottleneck Principle, Tishby et al. 33 | 2. Deep Variational Information Bottleneck, Alemi et al. 34 | 3. Tensorflow Demo : https://github.com/alexalemi/vib_demo 35 | 36 | [paper]: http://arxiv.org/abs/1612.00410 37 | [original code]: https://github.com/alexalemi/vib_demo 38 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.datasets import MNIST 5 | 6 | class UnknownDatasetError(Exception): 7 | def __str__(self): 8 | return "unknown datasets error" 9 | 10 | def return_data(args): 11 | name = args.dataset 12 | dset_dir = args.dset_dir 13 | batch_size = args.batch_size 14 | transform = transforms.Compose([transforms.ToTensor(), 15 | transforms.Normalize((0.5,), (0.5,)),]) 16 | 17 | if 'MNIST' in name : 18 | root = os.path.join(dset_dir,'MNIST') 19 | train_kwargs = {'root':root,'train':True,'transform':transform,'download':True} 20 | test_kwargs = {'root':root,'train':False,'transform':transform,'download':False} 21 | dset = MNIST 22 | 23 | else : raise UnknownDatasetError() 24 | 25 | train_data = dset(**train_kwargs) 26 | train_loader = DataLoader(train_data, 27 | batch_size=batch_size, 28 | shuffle=True, 29 | num_workers=1, 30 | drop_last=True) 31 | 32 | test_data = dset(**test_kwargs) 33 | test_loader = DataLoader(test_data, 34 | batch_size=batch_size, 35 | shuffle=False, 36 | num_workers=1, 37 | drop_last=False) 38 | 39 | data_loader = dict() 40 | data_loader['train']=train_loader 41 | data_loader['test']=test_loader 42 | 43 | return data_loader 44 | 45 | 46 | if __name__ == '__main__' : 47 | import argparse 48 | os.chdir('..') 49 | 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--dataset', default='MNIST', type=str) 52 | parser.add_argument('--dset_dir', default='datasets', type=str) 53 | parser.add_argument('--batch_size', default=64, type=int) 54 | args = parser.parse_args() 55 | 56 | data_loader = return_data(args) 57 | import ipdb; ipdb.set_trace() 58 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | from utils import str2bool 5 | from solver import Solver 6 | 7 | 8 | def main(args): 9 | torch.backends.cudnn.enabled = True 10 | torch.backends.cudnn.benchmark = True 11 | 12 | seed = args.seed 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | np.random.seed(seed) 16 | 17 | np.set_printoptions(precision=4) 18 | torch.set_printoptions(precision=4) 19 | 20 | print() 21 | print('[ARGUMENTS]') 22 | print(args) 23 | print() 24 | 25 | net = Solver(args) 26 | 27 | if args.mode == 'train' : net.train() 28 | elif args.mode == 'test' : net.test(save_ckpt=False) 29 | else : return 0 30 | 31 | if __name__ == "__main__": 32 | 33 | parser = argparse.ArgumentParser(description='TOY VIB') 34 | parser.add_argument('--epoch', default = 200, type=int, help='epoch size') 35 | parser.add_argument('--lr', default = 1e-4, type=float, help='learning rate') 36 | parser.add_argument('--beta', default = 1e-3, type=float, help='beta') 37 | parser.add_argument('--K', default = 256, type=int, help='dimension of encoding Z') 38 | parser.add_argument('--seed', default = 1, type=int, help='random seed') 39 | parser.add_argument('--num_avg', default = 12, type=int, help='the number of samples when\ 40 | perform multi-shot prediction') 41 | parser.add_argument('--batch_size', default = 100, type=int, help='batch size') 42 | parser.add_argument('--env_name', default='main', type=str, help='visdom env name') 43 | parser.add_argument('--dataset', default='MNIST', type=str, help='dataset name') 44 | parser.add_argument('--dset_dir', default='datasets', type=str, help='dataset directory path') 45 | parser.add_argument('--summary_dir', default='summary', type=str, help='summary directory path') 46 | parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory path') 47 | parser.add_argument('--load_ckpt',default='', type=str, help='checkpoint name') 48 | parser.add_argument('--cuda',default=True, type=str2bool, help='enable cuda') 49 | parser.add_argument('--mode',default='train', type=str, help='train or test') 50 | parser.add_argument('--tensorboard',default=False, type=str2bool, help='enable tensorboard') 51 | args = parser.parse_args() 52 | 53 | main(args) 54 | -------------------------------------------------------------------------------- /misc/ELBO.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/VIB-pytorch/dad74f78439dad2eabfe3de506b62c35ed0a35de/misc/ELBO.PNG -------------------------------------------------------------------------------- /misc/monte_carlo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/VIB-pytorch/dad74f78439dad2eabfe3de506b62c35ed0a35de/misc/monte_carlo.PNG -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.autograd import Variable 6 | from utils import cuda 7 | 8 | import time 9 | from numbers import Number 10 | 11 | class ToyNet(nn.Module): 12 | 13 | def __init__(self, K=256): 14 | super(ToyNet, self).__init__() 15 | self.K = K 16 | 17 | self.encode = nn.Sequential( 18 | nn.Linear(784, 1024), 19 | nn.ReLU(True), 20 | nn.Linear(1024, 1024), 21 | nn.ReLU(True), 22 | nn.Linear(1024, 2*self.K)) 23 | 24 | self.decode = nn.Sequential( 25 | nn.Linear(self.K, 10)) 26 | 27 | def forward(self, x, num_sample=1): 28 | if x.dim() > 2 : x = x.view(x.size(0),-1) 29 | 30 | statistics = self.encode(x) 31 | mu = statistics[:,:self.K] 32 | std = F.softplus(statistics[:,self.K:]-5,beta=1) 33 | 34 | encoding = self.reparametrize_n(mu,std,num_sample) 35 | logit = self.decode(encoding) 36 | 37 | if num_sample == 1 : pass 38 | elif num_sample > 1 : logit = F.softmax(logit, dim=2).mean(0) 39 | 40 | return (mu, std), logit 41 | 42 | def reparametrize_n(self, mu, std, n=1): 43 | # reference : 44 | # http://pytorch.org/docs/0.3.1/_modules/torch/distributions.html#Distribution.sample_n 45 | def expand(v): 46 | if isinstance(v, Number): 47 | return torch.Tensor([v]).expand(n, 1) 48 | else: 49 | return v.expand(n, *v.size()) 50 | 51 | if n != 1 : 52 | mu = expand(mu) 53 | std = expand(std) 54 | 55 | eps = Variable(cuda(std.data.new(std.size()).normal_(), std.is_cuda)) 56 | 57 | return mu + eps * std 58 | 59 | def weight_init(self): 60 | for m in self._modules: 61 | xavier_init(self._modules[m]) 62 | 63 | 64 | def xavier_init(ms): 65 | for m in ms : 66 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 67 | nn.init.xavier_uniform(m.weight,gain=nn.init.calculate_gain('relu')) 68 | m.bias.data.zero_() 69 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import math 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from torch.optim import lr_scheduler 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms 12 | from tensorboardX import SummaryWriter 13 | from utils import cuda, Weight_EMA_Update 14 | from datasets.datasets import return_data 15 | from model import ToyNet 16 | from pathlib import Path 17 | 18 | class Solver(object): 19 | 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | self.cuda = (args.cuda and torch.cuda.is_available()) 24 | self.epoch = args.epoch 25 | self.batch_size = args.batch_size 26 | self.lr = args.lr 27 | self.eps = 1e-9 28 | self.K = args.K 29 | self.beta = args.beta 30 | self.num_avg = args.num_avg 31 | self.global_iter = 0 32 | self.global_epoch = 0 33 | 34 | # Network & Optimizer 35 | self.toynet = cuda(ToyNet(self.K), self.cuda) 36 | self.toynet.weight_init() 37 | self.toynet_ema = Weight_EMA_Update(cuda(ToyNet(self.K), self.cuda),\ 38 | self.toynet.state_dict(), decay=0.999) 39 | 40 | self.optim = optim.Adam(self.toynet.parameters(),lr=self.lr,betas=(0.5,0.999)) 41 | self.scheduler = lr_scheduler.ExponentialLR(self.optim,gamma=0.97) 42 | 43 | self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name) 44 | if not self.ckpt_dir.exists() : self.ckpt_dir.mkdir(parents=True,exist_ok=True) 45 | self.load_ckpt = args.load_ckpt 46 | if self.load_ckpt != '' : self.load_checkpoint(self.load_ckpt) 47 | 48 | # History 49 | self.history = dict() 50 | self.history['avg_acc']=0. 51 | self.history['info_loss']=0. 52 | self.history['class_loss']=0. 53 | self.history['total_loss']=0. 54 | self.history['epoch']=0 55 | self.history['iter']=0 56 | 57 | # Tensorboard 58 | self.tensorboard = args.tensorboard 59 | if self.tensorboard : 60 | self.env_name = args.env_name 61 | self.summary_dir = Path(args.summary_dir).joinpath(args.env_name) 62 | if not self.summary_dir.exists() : self.summary_dir.mkdir(parents=True,exist_ok=True) 63 | self.tf = SummaryWriter(log_dir=self.summary_dir) 64 | self.tf.add_text(tag='argument',text_string=str(args),global_step=self.global_epoch) 65 | 66 | # Dataset 67 | self.data_loader = return_data(args) 68 | 69 | def set_mode(self,mode='train'): 70 | if mode == 'train' : 71 | self.toynet.train() 72 | self.toynet_ema.model.train() 73 | elif mode == 'eval' : 74 | self.toynet.eval() 75 | self.toynet_ema.model.eval() 76 | else : raise('mode error. It should be either train or eval') 77 | 78 | def train(self): 79 | self.set_mode('train') 80 | for e in range(self.epoch) : 81 | self.global_epoch += 1 82 | 83 | for idx, (images,labels) in enumerate(self.data_loader['train']): 84 | self.global_iter += 1 85 | 86 | x = Variable(cuda(images, self.cuda)) 87 | y = Variable(cuda(labels, self.cuda)) 88 | (mu, std), logit = self.toynet(x) 89 | 90 | class_loss = F.cross_entropy(logit,y).div(math.log(2)) 91 | info_loss = -0.5*(1+2*std.log()-mu.pow(2)-std.pow(2)).sum(1).mean().div(math.log(2)) 92 | total_loss = class_loss + self.beta*info_loss 93 | 94 | izy_bound = math.log(10,2) - class_loss 95 | izx_bound = info_loss 96 | 97 | self.optim.zero_grad() 98 | total_loss.backward() 99 | self.optim.step() 100 | self.toynet_ema.update(self.toynet.state_dict()) 101 | 102 | prediction = F.softmax(logit,dim=1).max(1)[1] 103 | accuracy = torch.eq(prediction,y).float().mean() 104 | 105 | if self.num_avg != 0 : 106 | _, avg_soft_logit = self.toynet(x,self.num_avg) 107 | avg_prediction = avg_soft_logit.max(1)[1] 108 | avg_accuracy = torch.eq(avg_prediction,y).float().mean() 109 | else : avg_accuracy = Variable(cuda(torch.zeros(accuracy.size()), self.cuda)) 110 | 111 | if self.global_iter % 100 == 0 : 112 | print('i:{} IZY:{:.2f} IZX:{:.2f}' 113 | .format(idx+1, izy_bound.data[0], izx_bound.data[0]), end=' ') 114 | print('acc:{:.4f} avg_acc:{:.4f}' 115 | .format(accuracy.data[0], avg_accuracy.data[0]), end=' ') 116 | print('err:{:.4f} avg_err:{:.4f}' 117 | .format(1-accuracy.data[0], 1-avg_accuracy.data[0])) 118 | 119 | if self.global_iter % 10 == 0 : 120 | if self.tensorboard : 121 | self.tf.add_scalars(main_tag='performance/accuracy', 122 | tag_scalar_dict={ 123 | 'train_one-shot':accuracy.data[0], 124 | 'train_multi-shot':avg_accuracy.data[0]}, 125 | global_step=self.global_iter) 126 | self.tf.add_scalars(main_tag='performance/error', 127 | tag_scalar_dict={ 128 | 'train_one-shot':1-accuracy.data[0], 129 | 'train_multi-shot':1-avg_accuracy.data[0]}, 130 | global_step=self.global_iter) 131 | self.tf.add_scalars(main_tag='performance/cost', 132 | tag_scalar_dict={ 133 | 'train_one-shot_class':class_loss.data[0], 134 | 'train_one-shot_info':info_loss.data[0], 135 | 'train_one-shot_total':total_loss.data[0]}, 136 | global_step=self.global_iter) 137 | self.tf.add_scalars(main_tag='mutual_information/train', 138 | tag_scalar_dict={ 139 | 'I(Z;Y)':izy_bound.data[0], 140 | 'I(Z;X)':izx_bound.data[0]}, 141 | global_step=self.global_iter) 142 | 143 | 144 | if (self.global_epoch % 2) == 0 : self.scheduler.step() 145 | self.test() 146 | 147 | print(" [*] Training Finished!") 148 | 149 | def test(self, save_ckpt=True): 150 | self.set_mode('eval') 151 | 152 | class_loss = 0 153 | info_loss = 0 154 | total_loss = 0 155 | izy_bound = 0 156 | izx_bound = 0 157 | correct = 0 158 | avg_correct = 0 159 | total_num = 0 160 | for idx, (images,labels) in enumerate(self.data_loader['test']): 161 | 162 | x = Variable(cuda(images, self.cuda)) 163 | y = Variable(cuda(labels, self.cuda)) 164 | (mu, std), logit = self.toynet_ema.model(x) 165 | 166 | class_loss += F.cross_entropy(logit,y,size_average=False).div(math.log(2)) 167 | info_loss += -0.5*(1+2*std.log()-mu.pow(2)-std.pow(2)).sum().div(math.log(2)) 168 | total_loss += class_loss + self.beta*info_loss 169 | total_num += y.size(0) 170 | 171 | izy_bound += math.log(10,2) - class_loss 172 | izx_bound += info_loss 173 | 174 | prediction = F.softmax(logit,dim=1).max(1)[1] 175 | correct += torch.eq(prediction,y).float().sum() 176 | 177 | if self.num_avg != 0 : 178 | _, avg_soft_logit = self.toynet_ema.model(x,self.num_avg) 179 | avg_prediction = avg_soft_logit.max(1)[1] 180 | avg_correct += torch.eq(avg_prediction,y).float().sum() 181 | else : 182 | avg_correct = Variable(cuda(torch.zeros(correct.size()), self.cuda)) 183 | 184 | accuracy = correct/total_num 185 | avg_accuracy = avg_correct/total_num 186 | 187 | izy_bound /= total_num 188 | izx_bound /= total_num 189 | class_loss /= total_num 190 | info_loss /= total_num 191 | total_loss /= total_num 192 | 193 | print('[TEST RESULT]') 194 | print('e:{} IZY:{:.2f} IZX:{:.2f}' 195 | .format(self.global_epoch, izy_bound.data[0], izx_bound.data[0]), end=' ') 196 | print('acc:{:.4f} avg_acc:{:.4f}' 197 | .format(accuracy.data[0], avg_accuracy.data[0]), end=' ') 198 | print('err:{:.4f} avg_erra:{:.4f}' 199 | .format(1-accuracy.data[0], 1-avg_accuracy.data[0])) 200 | print() 201 | 202 | if self.history['avg_acc'] < avg_accuracy.data[0] : 203 | self.history['avg_acc'] = avg_accuracy.data[0] 204 | self.history['class_loss'] = class_loss.data[0] 205 | self.history['info_loss'] = info_loss.data[0] 206 | self.history['total_loss'] = total_loss.data[0] 207 | self.history['epoch'] = self.global_epoch 208 | self.history['iter'] = self.global_iter 209 | if save_ckpt : self.save_checkpoint('best_acc.tar') 210 | 211 | if self.tensorboard : 212 | self.tf.add_scalars(main_tag='performance/accuracy', 213 | tag_scalar_dict={ 214 | 'test_one-shot':accuracy.data[0], 215 | 'test_multi-shot':avg_accuracy.data[0]}, 216 | global_step=self.global_iter) 217 | self.tf.add_scalars(main_tag='performance/error', 218 | tag_scalar_dict={ 219 | 'test_one-shot':1-accuracy.data[0], 220 | 'test_multi-shot':1-avg_accuracy.data[0]}, 221 | global_step=self.global_iter) 222 | self.tf.add_scalars(main_tag='performance/cost', 223 | tag_scalar_dict={ 224 | 'test_one-shot_class':class_loss.data[0], 225 | 'test_one-shot_info':info_loss.data[0], 226 | 'test_one-shot_total':total_loss.data[0]}, 227 | global_step=self.global_iter) 228 | self.tf.add_scalars(main_tag='mutual_information/test', 229 | tag_scalar_dict={ 230 | 'I(Z;Y)':izy_bound.data[0], 231 | 'I(Z;X)':izx_bound.data[0]}, 232 | global_step=self.global_iter) 233 | 234 | self.set_mode('train') 235 | 236 | def save_checkpoint(self, filename='best_acc.tar'): 237 | model_states = { 238 | 'net':self.toynet.state_dict(), 239 | 'net_ema':self.toynet_ema.model.state_dict(), 240 | } 241 | optim_states = { 242 | 'optim':self.optim.state_dict(), 243 | } 244 | states = { 245 | 'iter':self.global_iter, 246 | 'epoch':self.global_epoch, 247 | 'history':self.history, 248 | 'args':self.args, 249 | 'model_states':model_states, 250 | 'optim_states':optim_states, 251 | } 252 | 253 | file_path = self.ckpt_dir.joinpath(filename) 254 | torch.save(states,file_path.open('wb+')) 255 | print("=> saved checkpoint '{}' (iter {})".format(file_path,self.global_iter)) 256 | 257 | def load_checkpoint(self, filename='best_acc.tar'): 258 | file_path = self.ckpt_dir.joinpath(filename) 259 | if file_path.is_file(): 260 | print("=> loading checkpoint '{}'".format(file_path)) 261 | checkpoint = torch.load(file_path.open('rb')) 262 | self.global_epoch = checkpoint['epoch'] 263 | self.global_iter = checkpoint['iter'] 264 | self.history = checkpoint['history'] 265 | 266 | self.toynet.load_state_dict(checkpoint['model_states']['net']) 267 | self.toynet_ema.model.load_state_dict(checkpoint['model_states']['net_ema']) 268 | 269 | print("=> loaded checkpoint '{} (iter {})'".format( 270 | file_path, self.global_iter)) 271 | 272 | else: 273 | print("=> no checkpoint found at '{}'".format(file_path)) 274 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | 5 | 6 | def str2bool(v): 7 | """ 8 | codes from : https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 9 | """ 10 | 11 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | 19 | def cuda(tensor, is_cuda): 20 | if is_cuda : return tensor.cuda() 21 | else : return tensor 22 | 23 | 24 | class Weight_EMA_Update(object): 25 | 26 | def __init__(self, model, initial_state_dict, decay=0.999): 27 | self.model = model 28 | self.model.load_state_dict(initial_state_dict, strict=True) 29 | self.decay = decay 30 | 31 | def update(self, new_state_dict): 32 | state_dict = self.model.state_dict() 33 | for key in state_dict.keys(): 34 | state_dict[key] = (self.decay)*state_dict[key] + (1-self.decay)*new_state_dict[key] 35 | #state_dict[key] = (1-self.decay)*state_dict[key] + (self.decay)*new_state_dict[key] 36 | 37 | self.model.load_state_dict(state_dict) 38 | --------------------------------------------------------------------------------