├── .gitignore
├── README.md
├── datasets
└── datasets.py
├── main.py
├── misc
├── ELBO.PNG
└── monte_carlo.PNG
├── model.py
├── solver.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | summary
3 | checkpoints
4 | datasets/MNIST
5 |
6 | experiments.py
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Variational Information Bottleneck
2 |
3 |
4 | ### Overview
5 | PyTorch implementation of Deep Variational Information Bottleneck ([paper], [original code])
6 |
7 | 
8 | 
9 |
10 |
11 | ### Dependencies
12 | ```
13 | python 3.6.4
14 | pytorch 0.3.1.post2
15 | tensorboardX(optional)
16 | tensorflow(optional)
17 | ```
18 |
19 |
20 | ### Usage
21 | 1. train
22 | ```
23 | python main.py --mode train --beta 1e-3 --tensorboard True --env_name [NAME]
24 | ```
25 | 2. test
26 | ```
27 | python main.py --mode test --env_name [NAME] --load_ckpt best_acc.tar
28 | ```
29 |
30 |
31 | ### References
32 | 1. Deep Learning and the Information Bottleneck Principle, Tishby et al.
33 | 2. Deep Variational Information Bottleneck, Alemi et al.
34 | 3. Tensorflow Demo : https://github.com/alexalemi/vib_demo
35 |
36 | [paper]: http://arxiv.org/abs/1612.00410
37 | [original code]: https://github.com/alexalemi/vib_demo
38 |
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms
4 | from torchvision.datasets import MNIST
5 |
class UnknownDatasetError(Exception):
    """Raised by return_data() when the requested dataset is not supported."""

    def __str__(self):
        return "unknown datasets error"
9 |
def return_data(args):
    """Build train/test DataLoaders for the dataset named in ``args.dataset``.

    Expects ``args`` to carry ``dataset``, ``dset_dir`` and ``batch_size``.
    Only MNIST is supported; anything else raises UnknownDatasetError.
    Returns a dict with 'train' and 'test' loaders.
    """
    dataset_name = args.dataset
    batch_size = args.batch_size

    # Map pixels to [-1, 1]: ToTensor gives [0, 1], then (x - 0.5) / 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    if 'MNIST' not in dataset_name:
        raise UnknownDatasetError()

    root = os.path.join(args.dset_dir, 'MNIST')
    dset = MNIST
    # Download once via the train split; the test split reuses the same files.
    train_kwargs = {'root': root, 'train': True, 'transform': transform, 'download': True}
    test_kwargs = {'root': root, 'train': False, 'transform': transform, 'download': False}

    train_loader = DataLoader(dset(**train_kwargs),
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1,
                              drop_last=True)

    test_loader = DataLoader(dset(**test_kwargs),
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=1,
                             drop_last=False)

    return {'train': train_loader, 'test': test_loader}
44 |
45 |
if __name__ == '__main__' :
    # Smoke test: build the loaders with default arguments and report their sizes.
    import argparse
    os.chdir('..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='MNIST', type=str)
    parser.add_argument('--dset_dir', default='datasets', type=str)
    parser.add_argument('--batch_size', default=64, type=int)
    args = parser.parse_args()

    data_loader = return_data(args)
    # Previously this dropped into `ipdb.set_trace()`, which added a hard
    # third-party dependency and hung any non-interactive run. Print a
    # summary instead.
    for split, loader in data_loader.items():
        print('{}: {} batches (batch_size={})'.format(split, len(loader), loader.batch_size))
58 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | from utils import str2bool
5 | from solver import Solver
6 |
7 |
def main(args):
    """Seed all RNGs, configure printing, then dispatch to train or test.

    Returns 0 when ``args.mode`` is neither 'train' nor 'test'.
    """
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Seed every source of randomness for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    np.set_printoptions(precision=4)
    torch.set_printoptions(precision=4)

    print('\n[ARGUMENTS]\n{}\n'.format(args))

    net = Solver(args)

    if args.mode == 'train':
        net.train()
    elif args.mode == 'test':
        net.test(save_ckpt=False)
    else:
        return 0
30 |
if __name__ == "__main__":

    # CLI for the Deep Variational Information Bottleneck toy experiment.
    parser = argparse.ArgumentParser(description='TOY VIB')
    # Optimization hyperparameters.
    parser.add_argument('--epoch', default = 200, type=int, help='epoch size')
    parser.add_argument('--lr', default = 1e-4, type=float, help='learning rate')
    parser.add_argument('--beta', default = 1e-3, type=float, help='beta')
    # Model / sampling configuration.
    parser.add_argument('--K', default = 256, type=int, help='dimension of encoding Z')
    parser.add_argument('--seed', default = 1, type=int, help='random seed')
    parser.add_argument('--num_avg', default = 12, type=int, help='the number of samples when\
                        perform multi-shot prediction')
    parser.add_argument('--batch_size', default = 100, type=int, help='batch size')
    # Experiment naming and dataset/checkpoint/summary paths.
    parser.add_argument('--env_name', default='main', type=str, help='visdom env name')
    parser.add_argument('--dataset', default='MNIST', type=str, help='dataset name')
    parser.add_argument('--dset_dir', default='datasets', type=str, help='dataset directory path')
    parser.add_argument('--summary_dir', default='summary', type=str, help='summary directory path')
    parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory path')
    parser.add_argument('--load_ckpt',default='', type=str, help='checkpoint name')
    # Runtime switches.
    parser.add_argument('--cuda',default=True, type=str2bool, help='enable cuda')
    parser.add_argument('--mode',default='train', type=str, help='train or test')
    parser.add_argument('--tensorboard',default=False, type=str2bool, help='enable tensorboard')
    args = parser.parse_args()

    main(args)
54 |
--------------------------------------------------------------------------------
/misc/ELBO.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/VIB-pytorch/dad74f78439dad2eabfe3de506b62c35ed0a35de/misc/ELBO.PNG
--------------------------------------------------------------------------------
/misc/monte_carlo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/VIB-pytorch/dad74f78439dad2eabfe3de506b62c35ed0a35de/misc/monte_carlo.PNG
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 | from torch.autograd import Variable
6 | from utils import cuda
7 |
8 | import time
9 | from numbers import Number
10 |
class ToyNet(nn.Module):
    """MLP encoder/decoder for Deep VIB on flattened MNIST images.

    The encoder maps a 784-dim input to a K-dim stochastic code with a
    diagonal Gaussian posterior; the decoder is a linear classifier over
    a reparameterized sample of that code.
    """

    def __init__(self, K=256):
        super(ToyNet, self).__init__()
        self.K = K

        # Encoder emits 2K values per input: K means, then K pre-softplus scales.
        self.encode = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2 * self.K),
        )

        # Linear classifier over the stochastic code.
        self.decode = nn.Sequential(
            nn.Linear(self.K, 10),
        )

    def forward(self, x, num_sample=1):
        """Return ((mu, std), logit).

        For num_sample > 1 the returned "logit" is the class probability
        averaged over the samples (softmax taken per sample, then mean).
        """
        if x.dim() > 2:
            x = x.view(x.size(0), -1)  # flatten images to (N, 784)

        stats = self.encode(x)
        mu = stats[:, :self.K]
        # Shift by -5 so std starts small; softplus keeps it strictly positive.
        std = F.softplus(stats[:, self.K:] - 5, beta=1)

        z = self.reparametrize_n(mu, std, num_sample)
        logit = self.decode(z)

        if num_sample > 1:
            logit = F.softmax(logit, dim=2).mean(0)

        return (mu, std), logit

    def reparametrize_n(self, mu, std, n=1):
        """Draw n reparameterized samples z = mu + std * eps, eps ~ N(0, I)."""
        # Mirrors torch 0.3's Distribution.sample_n expansion behavior.
        def expand(v):
            if isinstance(v, Number):
                return torch.Tensor([v]).expand(n, 1)
            return v.expand(n, *v.size())

        if n != 1:
            mu, std = expand(mu), expand(std)

        eps = Variable(cuda(std.data.new(std.size()).normal_(), std.is_cuda))
        return mu + eps * std

    def weight_init(self):
        """Xavier-initialize every child module's Linear/Conv layers."""
        for module in self._modules.values():
            xavier_init(module)
62 |
63 |
def xavier_init(ms):
    """Xavier-uniform init (ReLU gain) for Linear/Conv2d weights; zero their biases."""
    for m in ms:
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform(m.weight, gain=nn.init.calculate_gain('relu'))
            m.bias.data.zero_()
69 |
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | import math
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | from torch.optim import lr_scheduler
10 | from torch.utils.data import DataLoader
11 | from torchvision import transforms
12 | from tensorboardX import SummaryWriter
13 | from utils import cuda, Weight_EMA_Update
14 | from datasets.datasets import return_data
15 | from model import ToyNet
16 | from pathlib import Path
17 |
class Solver(object):
    """Training/evaluation harness for the Deep VIB ToyNet.

    Owns the network, its EMA copy (used for evaluation), the optimizer,
    LR scheduler, data loaders, checkpointing and optional tensorboard
    logging. All losses/bounds are reported in bits (nats / log 2).
    """

    def __init__(self, args):
        self.args = args

        self.cuda = (args.cuda and torch.cuda.is_available())
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.eps = 1e-9
        self.K = args.K
        self.beta = args.beta
        self.num_avg = args.num_avg
        self.global_iter = 0
        self.global_epoch = 0

        # Network & Optimizer. The EMA copy tracks a decayed average of the
        # live weights and is the model actually used at test time.
        self.toynet = cuda(ToyNet(self.K), self.cuda)
        self.toynet.weight_init()
        self.toynet_ema = Weight_EMA_Update(cuda(ToyNet(self.K), self.cuda),\
                self.toynet.state_dict(), decay=0.999)

        self.optim = optim.Adam(self.toynet.parameters(),lr=self.lr,betas=(0.5,0.999))
        self.scheduler = lr_scheduler.ExponentialLR(self.optim,gamma=0.97)

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name)
        if not self.ckpt_dir.exists() : self.ckpt_dir.mkdir(parents=True,exist_ok=True)
        self.load_ckpt = args.load_ckpt
        if self.load_ckpt != '' : self.load_checkpoint(self.load_ckpt)

        # History of the best test result seen so far.
        self.history = dict()
        self.history['avg_acc']=0.
        self.history['info_loss']=0.
        self.history['class_loss']=0.
        self.history['total_loss']=0.
        self.history['epoch']=0
        self.history['iter']=0

        # Tensorboard (optional).
        self.tensorboard = args.tensorboard
        if self.tensorboard :
            self.env_name = args.env_name
            self.summary_dir = Path(args.summary_dir).joinpath(args.env_name)
            if not self.summary_dir.exists() : self.summary_dir.mkdir(parents=True,exist_ok=True)
            self.tf = SummaryWriter(log_dir=self.summary_dir)
            self.tf.add_text(tag='argument',text_string=str(args),global_step=self.global_epoch)

        # Dataset.
        self.data_loader = return_data(args)

    def set_mode(self,mode='train'):
        """Switch both the live net and its EMA copy between train/eval."""
        if mode == 'train' :
            self.toynet.train()
            self.toynet_ema.model.train()
        elif mode == 'eval' :
            self.toynet.eval()
            self.toynet_ema.model.eval()
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        else : raise ValueError('mode error. It should be either train or eval')

    def train(self):
        """Minimize class_loss + beta*info_loss over the train split."""
        self.set_mode('train')
        for e in range(self.epoch) :
            self.global_epoch += 1

            for idx, (images,labels) in enumerate(self.data_loader['train']):
                self.global_iter += 1

                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))
                (mu, std), logit = self.toynet(x)

                # Losses in bits. info_loss is KL(q(z|x) || N(0,I)) for a
                # diagonal Gaussian, averaged over the batch.
                class_loss = F.cross_entropy(logit,y).div(math.log(2))
                info_loss = -0.5*(1+2*std.log()-mu.pow(2)-std.pow(2)).sum(1).mean().div(math.log(2))
                total_loss = class_loss + self.beta*info_loss

                # Variational bounds on I(Z;Y) and I(Z;X) (Alemi et al.).
                izy_bound = math.log(10,2) - class_loss
                izx_bound = info_loss

                self.optim.zero_grad()
                total_loss.backward()
                self.optim.step()
                self.toynet_ema.update(self.toynet.state_dict())

                prediction = F.softmax(logit,dim=1).max(1)[1]
                accuracy = torch.eq(prediction,y).float().mean()

                if self.num_avg != 0 :
                    # Multi-shot prediction: average softmax over num_avg samples of z.
                    _, avg_soft_logit = self.toynet(x,self.num_avg)
                    avg_prediction = avg_soft_logit.max(1)[1]
                    avg_accuracy = torch.eq(avg_prediction,y).float().mean()
                else : avg_accuracy = Variable(cuda(torch.zeros(accuracy.size()), self.cuda))

                if self.global_iter % 100 == 0 :
                    print('i:{} IZY:{:.2f} IZX:{:.2f}'
                          .format(idx+1, izy_bound.data[0], izx_bound.data[0]), end=' ')
                    print('acc:{:.4f} avg_acc:{:.4f}'
                          .format(accuracy.data[0], avg_accuracy.data[0]), end=' ')
                    print('err:{:.4f} avg_err:{:.4f}'
                          .format(1-accuracy.data[0], 1-avg_accuracy.data[0]))

                if self.global_iter % 10 == 0 :
                    if self.tensorboard :
                        self.tf.add_scalars(main_tag='performance/accuracy',
                                            tag_scalar_dict={
                                                'train_one-shot':accuracy.data[0],
                                                'train_multi-shot':avg_accuracy.data[0]},
                                            global_step=self.global_iter)
                        self.tf.add_scalars(main_tag='performance/error',
                                            tag_scalar_dict={
                                                'train_one-shot':1-accuracy.data[0],
                                                'train_multi-shot':1-avg_accuracy.data[0]},
                                            global_step=self.global_iter)
                        self.tf.add_scalars(main_tag='performance/cost',
                                            tag_scalar_dict={
                                                'train_one-shot_class':class_loss.data[0],
                                                'train_one-shot_info':info_loss.data[0],
                                                'train_one-shot_total':total_loss.data[0]},
                                            global_step=self.global_iter)
                        self.tf.add_scalars(main_tag='mutual_information/train',
                                            tag_scalar_dict={
                                                'I(Z;Y)':izy_bound.data[0],
                                                'I(Z;X)':izx_bound.data[0]},
                                            global_step=self.global_iter)

            # Decay the learning rate every other epoch, then evaluate.
            if (self.global_epoch % 2) == 0 : self.scheduler.step()
            self.test()

        print(" [*] Training Finished!")

    def test(self, save_ckpt=True):
        """Evaluate the EMA model on the test split; checkpoint on a new best."""
        self.set_mode('eval')

        # Accumulate per-batch SUMS and normalize once after the loop.
        # (Previously the already-cumulative class_loss/info_loss were re-added
        # into total_loss/izy_bound/izx_bound on every batch, compounding them.)
        class_loss = 0
        info_loss = 0
        correct = 0
        avg_correct = 0
        total_num = 0
        for idx, (images,labels) in enumerate(self.data_loader['test']):

            x = Variable(cuda(images, self.cuda))
            y = Variable(cuda(labels, self.cuda))
            (mu, std), logit = self.toynet_ema.model(x)

            class_loss += F.cross_entropy(logit,y,size_average=False).div(math.log(2))
            info_loss += -0.5*(1+2*std.log()-mu.pow(2)-std.pow(2)).sum().div(math.log(2))
            total_num += y.size(0)

            prediction = F.softmax(logit,dim=1).max(1)[1]
            correct += torch.eq(prediction,y).float().sum()

            if self.num_avg != 0 :
                _, avg_soft_logit = self.toynet_ema.model(x,self.num_avg)
                avg_prediction = avg_soft_logit.max(1)[1]
                avg_correct += torch.eq(avg_prediction,y).float().sum()

        accuracy = correct/total_num
        # When multi-shot prediction is disabled report zero accuracy, matching
        # the train loop's convention. (Previously the running avg_correct was
        # reset to zeros inside the loop on every batch.)
        avg_accuracy = avg_correct/total_num if self.num_avg != 0 else accuracy*0

        class_loss /= total_num
        info_loss /= total_num
        total_loss = class_loss + self.beta*info_loss
        izy_bound = math.log(10,2) - class_loss
        izx_bound = info_loss

        print('[TEST RESULT]')
        print('e:{} IZY:{:.2f} IZX:{:.2f}'
              .format(self.global_epoch, izy_bound.data[0], izx_bound.data[0]), end=' ')
        print('acc:{:.4f} avg_acc:{:.4f}'
              .format(accuracy.data[0], avg_accuracy.data[0]), end=' ')
        print('err:{:.4f} avg_err:{:.4f}'
              .format(1-accuracy.data[0], 1-avg_accuracy.data[0]))
        print()

        if self.history['avg_acc'] < avg_accuracy.data[0] :
            self.history['avg_acc'] = avg_accuracy.data[0]
            self.history['class_loss'] = class_loss.data[0]
            self.history['info_loss'] = info_loss.data[0]
            self.history['total_loss'] = total_loss.data[0]
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter
            if save_ckpt : self.save_checkpoint('best_acc.tar')

        if self.tensorboard :
            self.tf.add_scalars(main_tag='performance/accuracy',
                                tag_scalar_dict={
                                    'test_one-shot':accuracy.data[0],
                                    'test_multi-shot':avg_accuracy.data[0]},
                                global_step=self.global_iter)
            self.tf.add_scalars(main_tag='performance/error',
                                tag_scalar_dict={
                                    'test_one-shot':1-accuracy.data[0],
                                    'test_multi-shot':1-avg_accuracy.data[0]},
                                global_step=self.global_iter)
            self.tf.add_scalars(main_tag='performance/cost',
                                tag_scalar_dict={
                                    'test_one-shot_class':class_loss.data[0],
                                    'test_one-shot_info':info_loss.data[0],
                                    'test_one-shot_total':total_loss.data[0]},
                                global_step=self.global_iter)
            self.tf.add_scalars(main_tag='mutual_information/test',
                                tag_scalar_dict={
                                    'I(Z;Y)':izy_bound.data[0],
                                    'I(Z;X)':izx_bound.data[0]},
                                global_step=self.global_iter)

        self.set_mode('train')

    def save_checkpoint(self, filename='best_acc.tar'):
        """Serialize nets, optimizer, counters and history to ckpt_dir/filename."""
        model_states = {
            'net':self.toynet.state_dict(),
            'net_ema':self.toynet_ema.model.state_dict(),
            }
        optim_states = {
            'optim':self.optim.state_dict(),
            }
        states = {
            'iter':self.global_iter,
            'epoch':self.global_epoch,
            'history':self.history,
            'args':self.args,
            'model_states':model_states,
            'optim_states':optim_states,
            }

        file_path = self.ckpt_dir.joinpath(filename)
        # Context manager so the handle is closed (it was previously leaked).
        with file_path.open('wb') as f:
            torch.save(states, f)
        print("=> saved checkpoint '{}' (iter {})".format(file_path,self.global_iter))

    def load_checkpoint(self, filename='best_acc.tar'):
        """Restore nets, counters and history from ckpt_dir/filename if present."""
        file_path = self.ckpt_dir.joinpath(filename)
        if file_path.is_file():
            print("=> loading checkpoint '{}'".format(file_path))
            with file_path.open('rb') as f:
                checkpoint = torch.load(f)
            self.global_epoch = checkpoint['epoch']
            self.global_iter = checkpoint['iter']
            self.history = checkpoint['history']

            self.toynet.load_state_dict(checkpoint['model_states']['net'])
            self.toynet_ema.model.load_state_dict(checkpoint['model_states']['net_ema'])

            print("=> loaded checkpoint '{} (iter {})'".format(
                file_path, self.global_iter))

        else:
            print("=> no checkpoint found at '{}'".format(file_path))
274 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 |
5 |
def str2bool(v):
    """Parse a CLI string into a bool for use as an argparse ``type=``.

    codes from : https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Accepts yes/no-style strings case-insensitively; anything else raises
    argparse.ArgumentTypeError so argparse prints a clean usage error.
    """
    # Local import: utils.py has no module-level argparse import, so the
    # original error path raised NameError instead of ArgumentTypeError.
    import argparse

    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
17 |
18 |
def cuda(tensor, is_cuda):
    """Move ``tensor`` to the GPU when ``is_cuda`` is truthy; otherwise return it unchanged."""
    return tensor.cuda() if is_cuda else tensor
22 |
23 |
class Weight_EMA_Update(object):
    """Maintains an exponential moving average of another model's weights.

    The wrapped ``model`` holds the averaged parameters; call ``update`` with
    the live model's state_dict after each optimizer step:
    ema <- decay * ema + (1 - decay) * new.
    """

    def __init__(self, model, initial_state_dict, decay=0.999):
        self.model = model
        self.model.load_state_dict(initial_state_dict, strict=True)
        self.decay = decay

    def update(self, new_state_dict):
        """Blend the incoming weights into the running average in one pass."""
        averaged = self.model.state_dict()
        for name in averaged.keys():
            averaged[name] = self.decay * averaged[name] + (1 - self.decay) * new_state_dict[name]

        self.model.load_state_dict(averaged)
38 |
--------------------------------------------------------------------------------