├── layers ├── __init__.py ├── merge.py └── pool_special.py ├── components ├── __init__.py ├── shortcuts.py └── objectives.py ├── images ├── ssl-mnist-data.png ├── ssl-norb-data.png ├── ssl-svhn-data.png ├── ssl-mnist-sample.png ├── ssl-norb-sample.png └── ssl-svhn-sample.png ├── utils ├── __init__.py ├── create_ssl_data.py ├── others.py └── paramgraphics.py ├── cdgm-mnist-sl.sh ├── cdgm-mnist-ssl_1000.sh ├── cdgm-norb-ssl_1000.sh ├── cdgm-svhn-ssl_1000.sh ├── cdgm-mnist-ssl_100.sh ├── README.md ├── datasets_norb.py ├── datasets.py ├── cdgm_x2y_xy2z_zy2x_sl.py └── cdgm_x2y_xy2z_zy2x.py /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .pool_special import * 2 | from .merge import * -------------------------------------------------------------------------------- /components/__init__.py: -------------------------------------------------------------------------------- 1 | from .shortcuts import * 2 | from .objectives import * -------------------------------------------------------------------------------- /images/ssl-mnist-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/mmdcgm-ssl/HEAD/images/ssl-mnist-data.png -------------------------------------------------------------------------------- /images/ssl-norb-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/mmdcgm-ssl/HEAD/images/ssl-norb-data.png -------------------------------------------------------------------------------- /images/ssl-svhn-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/mmdcgm-ssl/HEAD/images/ssl-svhn-data.png -------------------------------------------------------------------------------- /images/ssl-mnist-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/mmdcgm-ssl/HEAD/images/ssl-mnist-sample.png -------------------------------------------------------------------------------- /images/ssl-norb-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/mmdcgm-ssl/HEAD/images/ssl-norb-sample.png -------------------------------------------------------------------------------- /images/ssl-svhn-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/mmdcgm-ssl/HEAD/images/ssl-svhn-sample.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .paramgraphics import * 2 | from .others import * 3 | from .create_ssl_data import * -------------------------------------------------------------------------------- /cdgm-mnist-sl.sh: -------------------------------------------------------------------------------- 1 | THEANO_FLAGS=device=$1 python cdgm_x2y_xy2z_zy2x_sl.py -name 'sl-mnist' -dataset mnist_real -flag evaluation -preprocess none -batch_norm_classifier true -top_mlp false -mlp_size 256 -nlayers_cla 5 -nk_cla 32,64,64,128,128 -str_cla 1,1,1,1,1 -ps_cla 2,1,2,1,1 -dk_cla 5,3,3,3,3 -pad_cla valid,same,valid,same,same -nonlin_cla rectify,rectify,rectify,rectify,rectify -dr_cla 0.5,0,0.5,0,0 -nz 100 -batch_norm_dgm false -nlayers_enc 5 -nk_enc 
32,32,64,64,64 -dk_enc 5,3,3,3,3 -pad_enc valid,same,valid,same,same -str_enc 1,1,1,1,1 -ps_enc 2,1,2,1,1 -nonlin_enc rectify,rectify,rectify,rectify,rectify -dr_enc 0,0,0,0,0 -nlayers_dec 5 -nk_dec 64,64,32,32,1 -dk_dec 3,3,3,3,5 -pad_dec same,same,full,same,full -str_dec 1,1,1,1,1 -up_method none,none,unpool,none,unpool -ps_dec 1,1,2,1,2 -nonlin_dec rectify,rectify,rectify,rectify,sigmoid -dr_dec 0,0,0,0,0 -lr 3e-4 -nepochs 3000 -anneal_lr_epoch 1500 -anneal_lr_factor .995 -every_anneal 1 -delta 1.0 -batch_size 600 -alpha_decay 1e-4 -alpha .1 -------------------------------------------------------------------------------- /cdgm-mnist-ssl_1000.sh: -------------------------------------------------------------------------------- 1 | THEANO_FLAGS=device=$1,lib.cnmem=0.2 python cdgm_x2y_xy2z_zy2x.py -name 'ssl-mnist-1000' -dataset mnist_real -flag evaluation -ssl_data_seed $2 -preprocess none -batch_norm_classifier true -top_mlp false -nlayers_cla 5 -nk_cla 32,64,64,128,10 -str_cla 1,1,1,1,1 -ps_cla 2,1,2,1,1 -dk_cla 5,3,3,3,1 -pad_cla valid,same,valid,same,same -nonlin_cla rectify,rectify,rectify,rectify,rectify -dr_cla 0.5,0,0.5,0,0 -nz 100 -batch_norm_dgm false -nlayers_enc 5 -nk_enc 32,32,64,64,64 -dk_enc 5,3,3,3,3 -pad_enc valid,same,valid,same,same -str_enc 1,1,1,1,1 -ps_enc 2,1,2,1,1 -nonlin_enc rectify,rectify,rectify,rectify,rectify -dr_enc 0,0,0,0,0 -nlayers_dec 5 -nk_dec 64,64,32,32,1 -dk_dec 3,3,3,3,5 -pad_dec same,same,full,same,full -str_dec 1,1,1,1,1 -up_method none,none,unpool,none,unpool -ps_dec 1,1,2,1,2 -nonlin_dec rectify,rectify,rectify,rectify,sigmoid -dr_dec 0,0,0,0,0 -lr 3e-4 -nepochs 3000 -anneal_lr_epoch 1500 -anneal_lr_factor .995 -delta 1.0 -num_labelled_per_batch 250 -num_labelled 1000 -batch_size 600 -alpha_decay 1e-4 -alpha_hinge 1. -alpha_hat .3 -alpha_reg 1e-3 -alpha .1 -alpha_straight_through 3e-4 -------------------------------------------------------------------------------- /cdgm-norb-ssl_1000.sh: -------------------------------------------------------------------------------- 1 | THEANO_FLAGS=device=$1,lib.cnmem=.9 python cdgm_x2y_xy2z_zy2x.py -name 'ssl-norb-1000' -dataset norb -flag evaluation -ssl_data_seed $2 -preprocess none -batch_norm_classifier true -top_mlp false -nlayers_cla 6 -nk_cla 32,32,64,64,128,10 -str_cla 1,1,1,1,1,1 -ps_cla 1,2,1,2,1,1 -dk_cla 3,3,3,3,3,1 -pad_cla same,same,same,same,same,same -nonlin_cla rectify,rectify,rectify,rectify,rectify,rectify -dr_cla 0,0.2,0,0.2,0,0.2 -nz 100 -batch_norm_dgm false -nlayers_enc 5 -nk_enc 32,64,64,128,128 -dk_enc 5,3,3,3,3 -pad_enc same,same,same,same,same -str_enc 1,1,1,1,1 -ps_enc 2,1,2,1,2 -nonlin_enc rectify,rectify,rectify,rectify,rectify -dr_enc 0,0,0,0,0 -nlayers_dec 5 -nk_dec 128,64,64,32,1 -dk_dec 3,3,3,3,5 -pad_dec same,same,same,same,same -str_dec 1,1,1,1,1 -up_method unpool,none,unpool,none,unpool -ps_dec 2,1,2,1,2 -nonlin_dec rectify,rectify,rectify,rectify,sigmoid -dr_dec 0,0,0,0,0 -lr 3e-4 -nepochs 3000 -anneal_lr_epoch 2000 -anneal_lr_factor .995 -num_labelled_per_batch 1000 -num_labelled 1000 -batch_size 2000 -alpha_decay 1e-4 -alpha_hinge 1. 
-alpha_hat 0.3 -alpha_reg 1e-3 -alpha_straight_through 3e-5 -------------------------------------------------------------------------------- /cdgm-svhn-ssl_1000.sh: -------------------------------------------------------------------------------- 1 | THEANO_FLAGS=device=$1,lib.cnmem=1 python cdgm_x2y_xy2z_zy2x.py -name 'ssl-svhn-1000' -dataset svhn -flag evaluation -ssl_data_seed $2 -preprocess none -batch_norm_classifier true -top_mlp false -mlp_size 256 -nlayers_cla 6 -nk_cla 32,32,64,64,128,128 -str_cla 1,1,1,1,1,1 -ps_cla 1,2,1,2,1,1 -dk_cla 3,3,3,3,3,3 -pad_cla same,same,same,same,same,same -nonlin_cla rectify,rectify,rectify,rectify,rectify,rectify -dr_cla 0,0.2,0,0.2,0,0.2 -nz 128 -batch_norm_dgm false -nlayers_enc 5 -nk_enc 32,64,64,128,128 -dk_enc 5,3,3,3,3 -pad_enc same,same,same,same,same -str_enc 1,1,1,1,1 -ps_enc 2,1,2,1,2 -nonlin_enc rectify,rectify,rectify,rectify,rectify -dr_enc 0,0,0,0,0 -nlayers_dec 5 -nk_dec 128,64,64,32,3 -dk_dec 3,3,3,3,5 -pad_dec same,same,same,same,same -str_dec 1,1,1,1,1 -up_method unpool,none,unpool,none,unpool -ps_dec 2,1,2,1,2 -nonlin_dec rectify,rectify,rectify,rectify,sigmoid -dr_dec 0,0,0,0,0 -lr 3e-4 -nepochs 500 -anneal_lr_epoch 250 -anneal_lr_factor .99 -num_labelled_per_batch 500 -num_labelled 1000 -batch_size 1000 -alpha_decay 1e-4 -alpha_hinge 1. -alpha_hat 0.3 -alpha_reg 1e-3 -alpha_straight_through 1e-4 -------------------------------------------------------------------------------- /cdgm-mnist-ssl_100.sh: -------------------------------------------------------------------------------- 1 | THEANO_FLAGS=device=$1,lib.cnmem=0.2 python cdgm_x2y_xy2z_zy2x.py -name 'ssl-mnist-100' -dataset mnist_real -flag evaluation -ssl_data_seed $2 -preprocess none -batch_norm_classifier true -top_mlp false -mlp_size 256 -nlayers_cla 5 -nk_cla 32,64,64,128,128 -str_cla 1,1,1,1,1 -ps_cla 2,1,2,1,1 -dk_cla 5,3,3,3,3 -pad_cla valid,same,valid,same,same -nonlin_cla rectify,rectify,rectify,rectify,rectify -dr_cla 0.5,0,0.5,0,0 -nz 100 -batch_norm_dgm false -nlayers_enc 5 -nk_enc 32,32,64,64,64 -dk_enc 5,3,3,3,3 -pad_enc valid,same,valid,same,same -str_enc 1,1,1,1,1 -ps_enc 2,1,2,1,1 -nonlin_enc rectify,rectify,rectify,rectify,rectify -dr_enc 0,0,0,0,0 -nlayers_dec 5 -nk_dec 64,64,32,32,1 -dk_dec 3,3,3,3,5 -pad_dec same,same,full,same,full -str_dec 1,1,1,1,1 -up_method none,none,unpool,none,unpool -ps_dec 1,1,2,1,2 -nonlin_dec rectify,rectify,rectify,rectify,sigmoid -dr_dec 0,0,0,0,0 -lr 3e-4 -nepochs 3000 -anneal_lr_epoch 1500 -anneal_lr_factor .995 -every_anneal 1 -delta 1.0 -num_labelled_per_batch 100 -num_labelled 100 -batch_size 600 -alpha_decay 1e-4 -alpha_hinge 1. 
-alpha_hat .3 -alpha_reg 1e-3 -alpha .1 -alpha_straight_through 3e-4 -------------------------------------------------------------------------------- /layers/merge.py: -------------------------------------------------------------------------------- 1 | import lasagne 2 | from lasagne import init 3 | from lasagne import nonlinearities 4 | 5 | import theano.tensor as T 6 | import theano 7 | import numpy as np 8 | import theano.tensor.extra_ops as Textra 9 | 10 | __all__ = [ 11 | "ConvConcatLayer", # 12 | "MLPConcatLayer", # 13 | ] 14 | 15 | 16 | class ConvConcatLayer(lasagne.layers.MergeLayer): 17 | ''' 18 | concatenate a tensor and a vector on feature map axis 19 | ''' 20 | def __init__(self, incomings, num_cls, **kwargs): 21 | super(ConvConcatLayer, self).__init__(incomings, **kwargs) 22 | self.num_cls = num_cls 23 | 24 | def get_output_shape_for(self, input_shapes): 25 | res = list(input_shapes[0]) 26 | res[1] += self.num_cls 27 | return tuple(res) 28 | 29 | def get_output_for(self, input, **kwargs): 30 | x, y = input 31 | if y.ndim == 1: 32 | y = T.extra_ops.to_one_hot(y, self.num_cls) 33 | if y.ndim == 2: 34 | y = y.dimshuffle(0, 1, 'x', 'x') 35 | assert y.ndim == 4 36 | return T.concatenate([x, y*T.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))], axis=1) 37 | 38 | class MLPConcatLayer(lasagne.layers.MergeLayer): 39 | ''' 40 | concatenate a matrix and a vector on feature axis 41 | ''' 42 | def __init__(self, incomings, num_cls, **kwargs): 43 | super(MLPConcatLayer, self).__init__(incomings, **kwargs) 44 | self.num_cls = num_cls 45 | 46 | def get_output_shape_for(self, input_shapes): 47 | res = list(input_shapes[0]) 48 | res[1] += self.num_cls 49 | return tuple(res) 50 | 51 | def get_output_for(self, input, **kwargs): 52 | x, y = input 53 | if y.ndim == 1: 54 | y = T.extra_ops.to_one_hot(y, self.num_cls) 55 | assert y.ndim == 2 56 | return T.concatenate([x, y], axis=1) -------------------------------------------------------------------------------- /utils/create_ssl_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Create semi-supervised datasets for different models 3 | ''' 4 | import numpy as np 5 | 6 | def create_ssl_data(x, y, n_classes, n_labelled, seed): 7 | # 'x': data matrix, nxk 8 | # 'y': label vector, n 9 | # 'n_classes': number of classes 10 | # 'n_labelled': number of labelled data 11 | # 'seed': random seed 12 | 13 | # check input 14 | if n_labelled%n_classes != 0: 15 | print n_labelled 16 | print n_classes 17 | raise("n_labelled (wished number of labelled samples) not divisible by n_classes (number of classes)") 18 | n_labels_per_class = n_labelled/n_classes 19 | 20 | rng = np.random.RandomState(seed) 21 | index = rng.permutation(x.shape[0]) 22 | x = x[index] 23 | y = y[index] 24 | 25 | # select first several data per class 26 | data_labelled = [0]*n_classes 27 | index_labelled = [] 28 | index_unlabelled = [] 29 | for i in xrange(x.shape[0]): 30 | if data_labelled[y[i]] < n_labels_per_class: 31 | data_labelled[y[i]] += 1 32 | index_labelled.append(i) 33 | else: 34 | index_unlabelled.append(i) 35 | 36 | x_labelled = x[index_labelled] 37 | y_labelled = y[index_labelled] 38 | x_unlabelled = x[index_unlabelled] 39 | y_unlabelled = y[index_unlabelled] 40 | return x_labelled, y_labelled, x_unlabelled, y_unlabelled 41 | 42 | 43 | def create_ssl_data_subset(x, y, n_classes, n_labelled, n_labelled_per_time, seed): 44 | assert n_labelled%n_labelled_per_time==0 45 | times = n_labelled/n_labelled_per_time 46 | x_labelled, 
y_labelled, x_unlabelled, y_unlabelled = create_ssl_data(x, y, n_classes, n_labelled_per_time, seed) 47 | while (times > 1): 48 | x_labelled_new, y_labelled_new, x_unlabelled, y_unlabelled = create_ssl_data(x_unlabelled, y_unlabelled, n_classes, n_labelled_per_time, seed) 49 | x_labelled = np.vstack((x_labelled, x_labelled_new)) 50 | y_labelled = np.hstack((y_labelled, y_labelled_new)) 51 | times -= 1 52 | return x_labelled, y_labelled, x_unlabelled, y_unlabelled -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Max-margin Deep Conditional Generative Models for Semi-Supervised Learning 2 | ## [Chongxuan Li](https://github.com/zhenxuan00), Jun Zhu and Bo Zhang 3 | 4 | Full [paper](https://arxiv.org/abs/1611.07119), a journal version of our NIPS15 paper (original [paper](https://arxiv.org/abs/1504.06787) and [code](https://github.com/zhenxuan00/mmdgm)). A novel class-conditional variant of mmDGMs is proposed. 5 | 6 | ## Summary of Max-margin Deep Conditional Generative Models (mmDCGMs) 7 | 8 | - We boost the effectiveness and efficiency of DGMs in semi-supervised learning by 9 | - Employing advanced CNNs as the x2y, xy2z and zy2x networks 10 | - Approximating the posterior inference of labels 11 | - Proposing powerful max-margin discriminative losses for labeled and unlabeled data 12 | - and the resulting mmDCGMs can 13 | - Perform efficient inference: constant time with respect to the number of classes 14 | - Achieve state-of-the-art classification results on several benchmarks: MNIST, SVHN and NORB with 1000 labels and MNIST with full labels 15 | - Disentangle classes and styles on raw images without preprocessing such as PCA, given a small number of labels 16 | 17 | ## Some libs we used in our experiments 18 | > Python 19 | > Numpy 20 | > Scipy 21 | > [Theano](https://github.com/Theano/Theano) 22 | > [Lasagne](https://github.com/Lasagne/Lasagne) 23 | > [Parmesan](https://github.com/casperkaae/parmesan) 24 | 25 | ## State-of-the-art results on MNIST, SVHN and NORB datasets with 1000 labels, and excellent results competitive with the best CNNs given all labels on MNIST 26 | 27 | > chmod +x *.sh 28 | 29 | > ./cdgm-svhn-ssl_1000.sh gpu0 1 (run the .sh files to obtain the corresponding results; the SSL scripts take the Theano device as the first argument and the ssl_data_seed as the second; 1 here is just an example seed) 30 | 31 | > For the small NORB dataset, please download the raw images in .mat format from [http://www.cs.nyu.edu/~ylclab/data/norb-v1.0-small/](http://www.cs.nyu.edu/~ylclab/data/norb-v1.0-small/) and run datasets_norb.convert_orig_to_np() to convert them into numpy format (a minimal conversion sketch appears below). 32 | 33 | > See Table 6 and Table 7 in the paper for the classification results. 34 | 35 | ## Class-conditional generation of raw images given a few labels 36 | 37 | ### Results on MNIST given 100 labels (left: 100 labeled images sorted by class; right: samples, where each row shares the same class and each column shares the same style.)
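A minimal conversion sketch for the small NORB step above (hypothetical invocation, not part of the original scripts): the hard-coded `path` and `path_orig` constants at the top of `datasets_norb.py` point to the author's machine and should be edited first, and `convert_orig_to_np()` additionally requires pylearn2 to read the `.mat` filetensor files.

```python
# Hypothetical one-off conversion; edit the hard-coded paths in datasets_norb.py first.
import datasets_norb

# Reads the downloaded smallnorb-*.mat.gz files and writes pickled numpy arrays,
# including the 48x48 and 32x32 downscaled versions used by load_numpy_dat(size=...).
datasets_norb.convert_orig_to_np()
```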
38 | 39 | 40 | 41 | ### Results on SVHN given 1000 labels 42 | 43 | 44 | ### Results on small NORB given 1000 labels 45 | 46 | -------------------------------------------------------------------------------- /utils/others.py: -------------------------------------------------------------------------------- 1 | import shutil, gzip, os, cPickle, time, math, operator, argparse 2 | 3 | import numpy as np 4 | import theano.tensor as T 5 | import theano, lasagne 6 | 7 | 8 | def get_pad(pad): 9 | if pad not in ['same', 'valid', 'full']: 10 | pad = tuple(map(int, pad.split('-'))) 11 | return pad 12 | 13 | def get_pad_list(pad_list): 14 | re_list = [] 15 | for p in pad_list: 16 | re_list.append(get_pad(p)) 17 | return re_list 18 | 19 | # nonlinearities 20 | def get_nonlin(nonlin): 21 | if nonlin == 'rectify': 22 | return lasagne.nonlinearities.rectify 23 | elif nonlin == 'leaky_rectify': 24 | return lasagne.nonlinearities.LeakyRectify(0.1) 25 | elif nonlin == 'tanh': 26 | return lasagne.nonlinearities.tanh 27 | elif nonlin == 'sigmoid': 28 | return lasagne.nonlinearities.sigmoid 29 | elif nonlin == 'maxout': 30 | return 'maxout' 31 | elif nonlin == 'none': 32 | return lasagne.nonlinearities.identity 33 | else: 34 | raise ValueError('invalid non-linearity \'' + nonlin + '\'') 35 | def get_nonlin_list(nonlin_list): 36 | re_list = [] 37 | for n in nonlin_list: 38 | re_list.append(get_nonlin(n)) 39 | return re_list 40 | 41 | def bernoullisample(x): 42 | return np.random.binomial(1,x,size=x.shape).astype(theano.config.floatX) 43 | 44 | def build_log_file(args, filename_script, extra=None): 45 | res_out = args.outfolder 46 | res_out += '_' 47 | res_out += args.name 48 | res_out += '_' 49 | if extra is not None: 50 | res_out += extra 51 | res_out += '_' 52 | res_out += str(int(time.time())) 53 | if not os.path.exists(res_out): 54 | os.makedirs(res_out) 55 | 56 | # write commandline parameters to header of logfile 57 | args_dict = vars(args) 58 | sorted_args = sorted(args_dict.items(), key=operator.itemgetter(0)) 59 | description = [] 60 | description.append('######################################################') 61 | description.append('# --Commandline Params--') 62 | for name, val in sorted_args: 63 | description.append("# " + name + ":\t" + str(val)) 64 | description.append('######################################################') 65 | 66 | logfile = os.path.join(res_out, 'logfile.log') 67 | model_out = os.path.join(res_out, 'model') 68 | with open(logfile,'w') as f: 69 | for l in description: 70 | f.write(l + '\n') 71 | return logfile, res_out 72 | 73 | def array2file_2D(array,logfile): 74 | assert len(array.shape) == 2, array.shape 75 | with open(logfile,'a') as f: 76 | for i in xrange(array.shape[0]): 77 | for j in xrange(array.shape[1]): 78 | f.write(str(array[i][j])+' ') 79 | f.write('\n') 80 | 81 | def printarray_2D(array, precise=2): 82 | assert len(array.shape) == 2, array.shape 83 | format = '%.'+str(precise)+'f' 84 | for i in xrange(array.shape[0]): 85 | for j in xrange(array.shape[1]): 86 | print format %array[i][j], 87 | print -------------------------------------------------------------------------------- /components/shortcuts.py: -------------------------------------------------------------------------------- 1 | ''' 2 | shortcuts for compsited layers 3 | ''' 4 | import numpy as np 5 | import theano.tensor as T 6 | import theano 7 | import lasagne 8 | 9 | from parmesan.distributions import log_stdnormal, log_normal2, log_bernoulli 10 | 11 | import sys 12 | sys.path.append("..") 13 | from 
layers.pool_special import UnPoolLayer, UnPoolMaskLayer, MaxPoolLocationLayer, RepeatUnPoolLayer 14 | from layers.merge import ConvConcatLayer, MLPConcatLayer 15 | 16 | # convolutional layer 17 | # following optional batch normalization, pooling and dropout 18 | def convlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name,output_mask=False,batch_size_act=0,W=lasagne.init.GlorotUniform(),b=lasagne.init.Constant(0.)): 19 | mask = None 20 | l = lasagne.layers.Conv2DLayer(l, num_filters=n_kerns, filter_size=(d_kerns,d_kerns), stride=stride, pad=pad, name="Conv-"+name, W=W, b=b, nonlinearity=nonlinearity) 21 | if bn: 22 | l = lasagne.layers.batch_norm(l, name="BN-"+name) 23 | if ps > 1: 24 | if output_mask: 25 | mask = MaxPoolLocationLayer(l,factor=(ps,ps),batch_size=batch_size_act) 26 | l = lasagne.layers.MaxPool2DLayer(l, pool_size=(ps,ps), name="Pool"+name) 27 | if dr > 0: 28 | l = lasagne.layers.DropoutLayer(l, p=dr, name="Drop-"+name) 29 | return l, mask 30 | 31 | # unpooling and convolutional layer 32 | # following optional batch normalization and dropout 33 | def unpoolconvlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name,type_='unpool',mask=None,W=lasagne.init.GlorotUniform(),b=lasagne.init.Constant(0.), noise_level=0): 34 | if ps > 1: 35 | if type_ == 'unpool': 36 | l = UnPoolLayer(incoming=l, factor=(ps,ps), name="UP-"+name) 37 | elif type_ == 'repeat': 38 | l = RepeatUnPoolLayer(incoming=l, factor=(ps,ps), name="UP_REP-"+name) 39 | elif type_ == 'unpoolmask': 40 | l = UnPoolMaskLayer(incoming=l, mask=mask, factor=(ps,ps), name="UP_MUSK-"+name, noise_level=noise_level) 41 | l = lasagne.layers.Conv2DLayer(l, num_filters=n_kerns, filter_size=(d_kerns,d_kerns), stride=stride, pad=pad, name="Conv-"+name, W=W, b=b, nonlinearity=nonlinearity) 42 | if bn: 43 | l = lasagne.layers.batch_norm(l, name="BN-"+name) 44 | if dr > 0: 45 | l = lasagne.layers.DropoutLayer(l, p=dr, name="Drop-"+name) 46 | return l 47 | 48 | # fractional strided convolutional layer 49 | # following optional batch normalization and dropout 50 | def fractionalstridedlayer(l,bn,dr,n_kerns,d_kerns,nonlinearity,pad,stride,name,W=lasagne.init.GlorotUniform(),b=lasagne.init.Constant(0.)): 51 | # print bn,dr,n_kerns,d_kerns,nonlinearity,pad,stride,name 52 | l = lasagne.layers.TransposedConv2DLayer(l, num_filters=n_kerns, filter_size=(d_kerns,d_kerns), stride=stride, crop=pad, name="FS_Conv-"+name, W=W, b=b, nonlinearity=nonlinearity) 53 | if bn: 54 | l = lasagne.layers.batch_norm(l, name="BN-"+name) 55 | if dr > 0: 56 | l = lasagne.layers.DropoutLayer(l, p=dr, name="Drop-"+name) 57 | return l 58 | 59 | # mlp layer 60 | # following optional batch normalization and dropout 61 | def mlplayer(l,bn,dr,num_units,nonlinearity,name): 62 | l = lasagne.layers.DenseLayer(l,num_units=num_units,nonlinearity=nonlinearity,name="MLP-"+name) 63 | if bn: 64 | l = lasagne.layers.batch_norm(l, name="BN-"+name) 65 | if dr > 0: 66 | l = lasagne.layers.DropoutLayer(l, p=dr, name="Drop-"+name) 67 | return l 68 | -------------------------------------------------------------------------------- /datasets_norb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io 3 | import os 4 | import gzip 5 | import cPickle 6 | 7 | # refer to kingma's code nips14-ssl: https://github.com/dpkingma/nips14-ssl 8 | path = '/home/chongxuan/mfs/data/small_norb/np/' 9 | 10 | def load_numpy_dat(size=48): 11 | with gzip.open(path+'train_dat_'+str(size)+'.pkl.gz', 'rb') as f: 12 | 
train_dat = cPickle.load(f) 13 | with gzip.open(path+'test_dat_'+str(size)+'.pkl.gz', 'rb') as f: 14 | test_dat = cPickle.load(f) 15 | return train_dat, test_dat 16 | 17 | def load_numpy_cat(): 18 | with gzip.open(path+'train_cat.pkl.gz', 'rb') as f: 19 | train_cat = cPickle.load(f) 20 | with gzip.open(path+'test_cat.pkl.gz', 'rb') as f: 21 | test_cat = cPickle.load(f) 22 | return train_cat, test_cat 23 | 24 | def load_numpy_info(): 25 | with gzip.open(path+'train_info.pkl.gz', 'rb') as f: 26 | train_info = cPickle.load(f) 27 | with gzip.open(path+'test_info.pkl.gz', 'rb') as f: 28 | test_info = cPickle.load(f) 29 | return train_info, test_info 30 | 31 | # Load dataset with 50 subclasses, merged to single matrices 32 | def load_numpy_subclasses(size=48, normalize=False, centered=False, convert_to_five=True): 33 | train_dat, test_dat = load_numpy_dat(size) 34 | train_info, test_info = load_numpy_info() 35 | train_cat, test_cat = load_numpy_cat() 36 | 37 | n = train_dat.shape[0] 38 | 39 | n_class = 5 #number of classes 40 | n_ipc = 10 #number of instances per class 41 | 42 | train_x_left = train_dat[:,0].reshape((n, -1)).T 43 | train_x_right = train_dat[:,1].reshape((n, -1)).T 44 | train_y = (train_cat[:]*n_ipc + train_info[:,0]).reshape(1, n) # computes which of the 50 subclasses 45 | 46 | test_x_left = test_dat[:,0].reshape((n, -1)).T 47 | test_x_right = test_dat[:,1].reshape((n, -1)).T 48 | test_y = (test_cat[:]*n_ipc + test_info[:,0]).reshape(1, n) 49 | 50 | x = np.hstack((train_x_left, train_x_right, test_x_left, test_x_right)) # computes which of the 50 subclasses 51 | y = np.hstack((train_y, train_y, test_y, test_y)) 52 | 53 | if convert_to_five: 54 | y = y/10 55 | if normalize: 56 | x = x/256.0 57 | if centered: 58 | x = x - x.mean(axis=0,keepdims=True) 59 | return x, y 60 | 61 | # Original data to numpy-format data 62 | def convert_orig_to_np(): 63 | from pylearn2.datasets.filetensor import read 64 | import gzip 65 | import cPickle 66 | # Load data 67 | path_orig = './data/small_norb/mat/' 68 | prefix_train = path_orig+'smallnorb-5x46789x9x18x6x2x96x96-training-' 69 | train_cat = read(gzip.open(prefix_train+'cat.mat.gz')) 70 | train_dat = read(gzip.open(prefix_train+'dat.mat.gz')) 71 | train_info = read(gzip.open(prefix_train+'info.mat.gz')) 72 | prefix_test = path_orig+'smallnorb-5x01235x9x18x6x2x96x96-testing-' 73 | test_cat = read(gzip.open(prefix_test+'cat.mat.gz')) 74 | test_dat = read(gzip.open(prefix_test+'dat.mat.gz')) 75 | test_info = read(gzip.open(prefix_test+'info.mat.gz')) 76 | 77 | # Save originals matrices to file 78 | files = (('train_cat', train_cat), ('train_dat_96', train_dat), ('train_info', train_info), ('test_cat', test_cat), ('test_dat_96', test_dat), ('test_info', test_info)) 79 | for fname, tensor in files: 80 | print 'Saving to ', fname, '...' 81 | with gzip.open(path+fname+'.pkl.gz','wb') as f: 82 | cPickle.dump(tensor, f) 83 | 84 | # Save downscaled version too 85 | w = 48 86 | files = (('train_dat', train_dat),('test_dat', test_dat)) 87 | for fname, tensor in files: 88 | print 'Generating downscaled version ' + fname + '...' 
89 | left = reshape_images(tensor[:,0,:,:], (w,w)) 90 | right = reshape_images(tensor[:,1,:,:], (w,w)) 91 | result = np.zeros((tensor.shape[0], 2, w,w), dtype=np.uint8) 92 | result[:,0,:,:] = left 93 | result[:,1,:,:] = right 94 | f = gzip.open(path+fname+'_'+str(w)+'.pkl.gz', 'wb') 95 | cPickle.dump(result, f) 96 | f.close() 97 | 98 | w = 32 99 | files = (('train_dat', train_dat),('test_dat', test_dat)) 100 | for fname, tensor in files: 101 | print 'Generating downscaled version ' + fname + '...' 102 | left = reshape_images(tensor[:,0,:,:], (w,w)) 103 | right = reshape_images(tensor[:,1,:,:], (w,w)) 104 | result = np.zeros((tensor.shape[0], 2, w, w), dtype=np.uint8) 105 | result[:,0,:,:] = left 106 | result[:,1,:,:] = right 107 | f = gzip.open(path+fname+'_'+str(w)+'.pkl.gz', 'wb') 108 | cPickle.dump(result, f) 109 | f.close() 110 | 111 | # Reshape digits 112 | def reshape_images(x, shape): 113 | def rebin(_a, shape): 114 | sh = shape[0],_a.shape[0]//shape[0],shape[1],_a.shape[1]//shape[1] 115 | result = _a.reshape(sh).mean(-1).mean(1) 116 | return np.floor(result).astype(np.uint8) 117 | nrows = x.shape[0] 118 | result = np.zeros((nrows, shape[0], shape[1]), dtype=np.uint8) 119 | for i in range(nrows): 120 | result[i,:,:] = rebin(x[i,:,:], shape) 121 | return result 122 | 123 | # Converts integer labels to binarized labels (1-of-K coding) 124 | def binarize_labels(y, n_classes=5): 125 | new_y = np.zeros((n_classes, y.shape[0])) 126 | for i in range(y.shape[0]): 127 | new_y[y[i], i] = 1 128 | return new_y 129 | -------------------------------------------------------------------------------- /utils/paramgraphics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from scipy.misc import imsave 4 | 5 | def scale_max_min(images, max_p, min_p): 6 | # scale the images according to the max and min 7 | # images f x n, column major 8 | ret = np.zeros(images.shape) 9 | for i in xrange(images.shape[1]): 10 | # clips at first 11 | tmp = np.clip(images[:,i], min_p[i], max_p[i]) 12 | # scale 13 | ret[:,i] = (tmp - min_p[i]) / (max_p[i] - min_p[i]) 14 | 15 | return ret 16 | 17 | def scale_to_unit_interval(ndar, eps=1e-8): 18 | """ Scales all values in the ndarray ndar to be between 0 and 1 """ 19 | ndar = ndar.copy() 20 | ndar -= ndar.min() 21 | ndar *= 1.0 / (ndar.max() + eps) 22 | return ndar 23 | 24 | def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0), 25 | scale=True, 26 | output_pixel_vals=True, 27 | colorImg=False): 28 | """ 29 | Transform an array with one flattened image per row, into an array in 30 | which images are reshaped and layed out like tiles on a floor. 31 | 32 | This function is useful for visualizing datasets whose rows are images, 33 | and also columns of matrices for transforming those rows 34 | (such as the first layer of a neural net). 35 | 36 | :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can 37 | be 2-D ndarrays or None; 38 | :param X: a 2-D array in which every row is a flattened image. 39 | 40 | :type img_shape: tuple; (height, width) 41 | :param img_shape: the original shape of each image 42 | 43 | :type tile_shape: tuple; (rows, cols) 44 | :param tile_shape: the number of images to tile (rows, cols) 45 | 46 | :param output_pixel_vals: if output should be pixel values (i.e. 
int8 47 | values) or floats 48 | 49 | :param scale_rows_to_unit_interval: if the values need to be scaled before 50 | being plotted to [0,1] or not 51 | 52 | 53 | :returns: array suitable for viewing as an image. 54 | """ 55 | X = X * 1.0 # converts ints to floats 56 | 57 | if colorImg: 58 | channelSize = X.shape[1]/3 59 | X = (X[:,0:channelSize], X[:,channelSize:2*channelSize], X[:,2*channelSize:3*channelSize], None) 60 | 61 | assert len(img_shape) == 2 62 | assert len(tile_shape) == 2 63 | assert len(tile_spacing) == 2 64 | 65 | # The expression below can be re-written in a more C style as 66 | # follows : 67 | # 68 | # out_shape = [0,0] 69 | # out_shape[0] = (img_shape[0] + tile_spacing[0]) * tile_shape[0] - 70 | # tile_spacing[0] 71 | # out_shape[1] = (img_shape[1] + tile_spacing[1]) * tile_shape[1] - 72 | # tile_spacing[1] 73 | out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp 74 | in zip(img_shape, tile_shape, tile_spacing)] 75 | 76 | if isinstance(X, tuple): 77 | assert len(X) == 4 78 | # Create an output np ndarray to store the image 79 | if output_pixel_vals: 80 | out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype='uint8') 81 | else: 82 | out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype) 83 | 84 | #colors default to 0, alpha defaults to 1 (opaque) 85 | if output_pixel_vals: 86 | channel_defaults = [0, 0, 0, 255] 87 | else: 88 | channel_defaults = [0., 0., 0., 1.] 89 | 90 | 91 | for i in xrange(4): 92 | if X[i] is None: 93 | # if channel is None, fill it with zeros of the correct 94 | # dtype 95 | out_array[:, :, i] = np.zeros(out_shape, 96 | dtype='uint8' if output_pixel_vals else out_array.dtype 97 | ) + channel_defaults[i] 98 | else: 99 | # use a recurrent call to compute the channel and store it 100 | # in the output 101 | xi = X[i] 102 | if scale: 103 | xi = (X[i] - X[i].min()) / (X[i].max() - X[i].min()) 104 | out_array[:, :, i] = tile_raster_images(xi, img_shape, tile_shape, tile_spacing, False, output_pixel_vals) 105 | 106 | 107 | return out_array 108 | 109 | else: 110 | # if we are dealing with only one channel 111 | H, W = img_shape 112 | Hs, Ws = tile_spacing 113 | 114 | # generate a matrix to store the output 115 | out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype) 116 | 117 | 118 | for tile_row in xrange(tile_shape[0]): 119 | for tile_col in xrange(tile_shape[1]): 120 | if tile_row * tile_shape[1] + tile_col < X.shape[0]: 121 | if scale: 122 | # if we should scale values to be between 0 and 1 123 | # do this by calling the `scale_to_unit_interval` 124 | # function 125 | tmp = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape) 126 | this_img = scale_to_unit_interval(tmp) 127 | else: 128 | this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape) 129 | # add the slice to the corresponding position in the 130 | # output array 131 | out_array[ 132 | tile_row * (H+Hs): tile_row * (H + Hs) + H, 133 | tile_col * (W+Ws): tile_col * (W + Ws) + W 134 | ] \ 135 | = this_img * (255 if output_pixel_vals else 1) 136 | return out_array 137 | 138 | # Matrix to image 139 | def mat_to_img(w, dim_input, scale=False, colorImg=False, tile_spacing=(1,1), tile_shape=0, save_path=None): 140 | if tile_shape == 0: 141 | rowscols = int(w.shape[1]**0.5) 142 | tile_shape = (rowscols,rowscols) 143 | imgs = tile_raster_images(X=w.T, img_shape=dim_input, tile_shape=tile_shape, tile_spacing=tile_spacing, scale=scale, colorImg=colorImg) 144 | if save_path is not None: 145 | imsave(save_path, imgs) 146 | return imgs 
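A minimal usage sketch for mat_to_img above (a hypothetical example, not part of the original repository; it assumes the repository root is on sys.path and uses illustrative shapes):

import numpy as np
from utils.paramgraphics import mat_to_img

# 100 flattened 28x28 images, one image per column, as mat_to_img expects
w = np.random.rand(28 * 28, 100)
# tile_shape defaults to a square grid (10x10 here); scale=True rescales each tile to [0, 1]
imgs = mat_to_img(w, dim_input=(28, 28), scale=True, colorImg=False, save_path='tiles.png')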
-------------------------------------------------------------------------------- /layers/pool_special.py: -------------------------------------------------------------------------------- 1 | import lasagne 2 | from lasagne import init 3 | from lasagne import nonlinearities 4 | 5 | import theano.tensor as T 6 | import theano 7 | import numpy as np 8 | import theano.tensor.extra_ops as Textra 9 | 10 | 11 | __all__ = [ 12 | "UnPoolLayer", # upsampling by setting input to the top-left corner 13 | "RepeatUnPoolLayer", # upsampling by repeating input 14 | "UnPoolMaskLayer", # upsampling with pooling location 15 | "MaxPoolLocationLayer", # get the location of max pooling 16 | ] 17 | 18 | class UnPoolLayer(lasagne.layers.Layer): 19 | ''' 20 | Layer that upsampling the input 21 | 22 | Parameters 23 | ---------- 24 | incoming: class `Layer` instance 25 | dim of incoming: B,C,0,1 26 | 27 | factor : tuple of length 2 28 | upsample factor 29 | ---------- 30 | ''' 31 | def __init__(self, incoming, factor, **kwargs): 32 | super(UnPoolLayer, self).__init__(incoming, **kwargs) 33 | assert len(factor) == 2 34 | assert len(self.input_shape) == 4 35 | self.factor = factor 36 | window = np.zeros(self.factor, dtype=np.float32) 37 | window[0, 0] = 1 38 | image_shape = self.input_shape[1:] 39 | self.mask = theano.shared(np.tile(window.reshape((1,)+self.factor), image_shape)) 40 | self.mask = T.shape_padleft(self.mask,n_ones=1) 41 | 42 | def get_output_shape_for(self, input_shape): 43 | return input_shape[:2] + (input_shape[2]*self.factor[0], input_shape[3]*self.factor[1]) 44 | 45 | def get_output_for(self, input, **kwargs): 46 | return Textra.repeat(Textra.repeat(input,self.factor[0],axis=2),self.factor[1],axis=3)*self.mask 47 | 48 | class RepeatUnPoolLayer(lasagne.layers.Layer): 49 | ''' 50 | Layer that upsampling the input 51 | one unit in the input corresponds a square of units in the output 52 | all values in the region are same as the corresponding value of input 53 | 54 | Parameters 55 | ---------- 56 | incoming: class `Layer` instance 57 | dim of incoming: B,C,0,1 58 | 59 | factor : tuple of length 2 60 | upsample factor 61 | ---------- 62 | ''' 63 | def __init__(self, incoming, factor, **kwargs): 64 | super(RepeatUnPoolLayer, self).__init__(incoming, **kwargs) 65 | assert len(factor) == 2 66 | assert len(self.input_shape) == 4 67 | self.factor = factor 68 | 69 | def get_output_shape_for(self, input_shape): 70 | return input_shape[:2] + (input_shape[2]*self.factor[0], input_shape[3]*self.factor[1]) 71 | 72 | def get_output_for(self, input, **kwargs): 73 | return Textra.repeat(Textra.repeat(input,self.factor[0],axis=2),self.factor[1],axis=3) 74 | 75 | class UnPoolMaskLayer(lasagne.layers.MergeLayer): 76 | ''' 77 | Layer that upsampling the input given the pooling location 78 | 79 | Parameters 80 | ---------- 81 | incoming, mask : class `Layer` instances 82 | dim of incoming: B,C,0,1 83 | dim of mask: B,C,0*f1,1*f2 84 | 85 | factor : tuple of length 2 86 | upsample factor 87 | ---------- 88 | ''' 89 | def __init__(self, incoming, mask, factor, noise_level=0.7, **kwargs): 90 | super(UnPoolMaskLayer, self).__init__([incoming, mask], **kwargs) 91 | assert len(factor) == 2 92 | assert len(self.input_shapes[0]) == 4 93 | assert len(self.input_shapes[1]) == 4 94 | assert self.input_shapes[0][2]*factor[0] == self.input_shapes[1][2] 95 | assert self.input_shapes[0][3]*factor[1] == self.input_shapes[1][3] 96 | assert noise_level>=0 and noise_level<=1 97 | self.factor = factor 98 | self.noise = noise_level 99 | 100 | 
def get_output_shape_for(self, input_shapes): 101 | return input_shapes[1] 102 | 103 | def get_output_for(self, input, **kwargs): 104 | data, mask_max = input 105 | #return Textra.repeat(Textra.repeat(data, self.factor[0], axis=2), self.factor[1], axis=3) * mask_max 106 | window = np.zeros(self.factor, dtype=np.float32) 107 | window[0, 0] = 1 108 | mask_unpool = np.tile(window.reshape((1,) + self.factor), self.input_shapes[0][1:]) 109 | mask_unpool = T.shape_padleft(mask_unpool, n_ones=1) 110 | 111 | rs = np.random.RandomState(1234) 112 | rng = theano.tensor.shared_randomstreams.RandomStreams(rs.randint(999999)) 113 | mask_binomial = rng.binomial(n=1, p=self.noise, size= self.input_shapes[1][1:]) 114 | mask_binomial = T.shape_padleft(T.cast(mask_binomial, dtype='float32'), n_ones=1) 115 | 116 | mask = mask_binomial * mask_unpool + (1 - mask_binomial) * mask_max 117 | return Textra.repeat(Textra.repeat(data,self.factor[0],axis=2),self.factor[1],axis=3)*mask 118 | 119 | class MaxPoolLocationLayer(lasagne.layers.Layer): 120 | ''' 121 | Layer that computes the max-pool location 122 | 123 | Parameters 124 | ---------- 125 | incoming : a class `Layer` instance 126 | output shape is 4D 127 | 128 | factor : tuple of length 2 129 | downsample, fixed to (2, 2) so far 130 | 131 | batch_size : tensor iscalar 132 | 133 | References 134 | ---------- 135 | ''' 136 | def __init__(self, incoming, factor, batch_size, noise_level=0.5, **kwargs): 137 | super(MaxPoolLocationLayer, self).__init__(incoming, **kwargs) 138 | assert factor[0] == 2, factor # only for special (2,2) case 139 | assert factor[1] == 2, factor 140 | self.factor = factor 141 | self.batch_size = batch_size 142 | self.n_channels = self.input_shape[1] 143 | self.i_s = self.input_shape[-2:] 144 | self.noise = noise_level 145 | 146 | def get_output_shape_for(self, input_shape): 147 | return input_shape 148 | 149 | def _get_output_for(self, input): 150 | assert input.ndim == 3 # only for 3D 151 | mask = T.zeros_like(input) # size (None, w, h) 152 | tmp = T.concatenate([T.shape_padright(input[:, ::2, ::2]), 153 | T.shape_padright(input[:, ::2, 1::2]), T.shape_padright(input[:, 1::2, ::2]), 154 | T.shape_padright(input[:, 1::2, 1::2])], axis=-1) 155 | index = tmp.argmax(axis=-1) # size (None, w/2, h/2) 156 | i_r = 2*(np.tile(np.arange(self.i_s[0]/2), (self.i_s[1]/2,1))).T 157 | i_r = index/2 + T.shape_padleft(i_r) 158 | i_c = 2*(np.tile(np.arange(self.i_s[1]/2), (self.i_s[0]/2,1))) 159 | i_c = index%2 + T.shape_padleft(i_c) 160 | i_b = T.tile(T.arange(self.batch_size*self.n_channels),(self.i_s[0]/2*self.i_s[1]/2,1)).T 161 | mask = T.set_subtensor(mask[i_b.flatten(), i_r.flatten(), i_c.flatten()],1) 162 | return mask 163 | 164 | def get_output_for(self, input, **kwargs): 165 | assert input.ndim == 4 # only for 4D 166 | input_3D = input.reshape((self.batch_size*self.n_channels,)+self.i_s) 167 | mask_max = self._get_output_for(input_3D) 168 | return mask_max.reshape((self.batch_size,self.n_channels)+self.i_s) 169 | 170 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle as pkl 3 | import cPickle as cPkl 4 | import gzip 5 | import tarfile 6 | import fnmatch 7 | import os 8 | import urllib 9 | from scipy.io import loadmat 10 | 11 | def _unpickle(f): 12 | import cPickle 13 | fo = open(f, 'rb') 14 | d = cPickle.load(fo) 15 | fo.close() 16 | return d 17 | 18 | def _get_datafolder_path(): 19 
| #full_path = os.path.abspath('.') 20 | #path = full_path +'/data' 21 | path = '/home/chongxuan/mfs/data' 22 | return path 23 | 24 | def _download_svhn(datasets_dir=_get_datafolder_path()+'/svhn/'): 25 | url = 'http://ufldl.stanford.edu/housenumbers/' 26 | data_file_list = ['train_32x32.mat', 'test_32x32.mat', 'extra_32x32.mat'] 27 | 28 | if not os.path.exists(datasets_dir): 29 | os.makedirs(datasets_dir) 30 | 31 | for data_file in data_file_list: 32 | if not os.path.isfile(os.path.join(datasets_dir,data_file)): 33 | urllib.urlretrieve(os.path.join(url,data_file), data_file) 34 | 35 | batch1_data = [] 36 | batch1_labels = [] 37 | batch2_data = [] 38 | batch2_labels = [] 39 | from random import shuffle 40 | 41 | train = loadmat(os.path.join(datasets_dir,data_file_list[0])) 42 | x = train['X'].transpose((2, 0, 1, 3)).reshape((3072, -1)) 43 | y = train['y'].reshape((-1,)) 44 | for i in np.arange(len(y)): 45 | if y[i] == 10: 46 | y[i] = 0 47 | index = np.arange(len(y)) 48 | shuffle(index) 49 | x = x[:, index] 50 | y = y[index] 51 | 52 | count = np.zeros((10,), 'int32') 53 | for i in np.arange(len(y)): 54 | if count[y[i]] < 400: 55 | count[y[i]] += 1 56 | batch2_data.append(x[:, i]) 57 | batch2_labels.append(y[i]) 58 | else: 59 | batch1_data.append(x[:, i]) 60 | batch1_labels.append(y[i]) 61 | 62 | print '---train' 63 | extra = loadmat(os.path.join(datasets_dir,data_file_list[2])) 64 | x = extra['X'].transpose((2, 0, 1, 3)).reshape((3072, -1)) 65 | y = extra['y'].reshape((-1,)) 66 | del extra 67 | for i in np.arange(len(y)): 68 | if y[i] == 10: 69 | y[i] = 0 70 | index = np.arange(len(y)) 71 | shuffle(index) 72 | x = x[:, index] 73 | y = y[index] 74 | 75 | count = np.zeros((10,), 'int32') 76 | for i in np.arange(len(y)): 77 | if count[y[i]] < 200: 78 | count[y[i]] += 1 79 | batch2_data.append(x[:, i]) 80 | batch2_labels.append(y[i]) 81 | else: 82 | batch1_data.append(x[:, i]) 83 | batch1_labels.append(y[i]) 84 | batch1_data = np.asarray(batch1_data) 85 | batch2_data = np.asarray(batch2_data) 86 | batch1_labels = np.asarray(batch1_labels) 87 | batch2_labels = np.asarray(batch2_labels) 88 | del x, y 89 | 90 | print '---extra' 91 | 92 | test = loadmat(os.path.join(datasets_dir,data_file_list[1])) 93 | x = test['X'].transpose((2, 0, 1, 3)).reshape((3072, -1)) 94 | y = test['y'].reshape((-1,)) 95 | for i in np.arange(len(y)): 96 | if y[i] == 10: 97 | y[i] = 0 98 | batch3_data = x 99 | batch3_labels = [] 100 | for i in np.arange(len(y)): 101 | batch3_labels.append(y[i]) 102 | batch3_data = np.asarray(batch3_data).T 103 | batch3_labels = np.asarray(batch3_labels) 104 | 105 | print 'Check n x f' 106 | print batch1_data.shape 107 | print batch1_labels.shape 108 | print batch2_data.shape 109 | print batch2_labels.shape 110 | print batch3_data.shape 111 | print batch3_labels.shape 112 | 113 | f = file(datasets_dir+"/svhn.bin","wb") 114 | np.save(f,batch1_data) 115 | np.save(f,batch1_labels) 116 | np.save(f,batch2_data) 117 | np.save(f,batch2_labels) 118 | np.save(f,batch3_data) 119 | np.save(f,batch3_labels) 120 | f.close() 121 | 122 | def load_svhn(datasets_dir=_get_datafolder_path()+'/svhn/', normalized=True, centered=True): 123 | data_file = os.path.join(datasets_dir, 'svhn.bin') 124 | 125 | if not os.path.exists(datasets_dir): 126 | os.makedirs(datasets_dir) 127 | 128 | if not os.path.isfile(data_file): 129 | _download_svhn() 130 | 131 | f = file(data_file,"rb") 132 | train_x = np.load(f) 133 | train_y = np.load(f) 134 | valid_x = np.load(f) 135 | valid_y = np.load(f) 136 | test_x = np.load(f) 137 | 
test_y = np.load(f) 138 | f.close() 139 | if normalized: 140 | train_x = train_x/256.0 141 | valid_x = valid_x/256.0 142 | test_x = test_x/256.0 143 | 144 | avg = None 145 | if centered: 146 | avg = train_x.mean(axis=0,keepdims=True) 147 | train_x = train_x - avg 148 | test_x = test_x - avg 149 | valid_x = valid_x - avg 150 | 151 | return train_x, train_y, valid_x, valid_y, test_x, test_y, avg 152 | 153 | def load_cifar10(datasets_dir=_get_datafolder_path()+'/cifar10', num_val=None, normalized=True, centered=True): 154 | # this code is largely cp from Kyle Kastner: 155 | # 156 | # https://gist.github.com/kastnerkyle/f3f67424adda343fef40 157 | 158 | url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 159 | data_file = os.path.join(datasets_dir, 'cifar-10-python.tar.gz') 160 | data_dir = os.path.join(datasets_dir, 'cifar-10-batches-py') 161 | 162 | if not os.path.exists(datasets_dir): 163 | os.makedirs(datasets_dir) 164 | 165 | if not os.path.isfile(data_file): 166 | urllib.urlretrieve(url, data_file) 167 | org_dir = os.getcwd() 168 | with tarfile.open(data_file) as tar: 169 | os.chdir(datasets_dir) 170 | tar.extractall() 171 | os.chdir(org_dir) 172 | 173 | train_files = [] 174 | for filepath in fnmatch.filter(os.listdir(data_dir), 'data*'): 175 | train_files.append(os.path.join(data_dir, filepath)) 176 | train_files = sorted(train_files, key=lambda x: x.split("_")[-1]) 177 | 178 | test_file = os.path.join(data_dir, 'test_batch') 179 | 180 | x_train, targets_train = [], [] 181 | for f in train_files: 182 | d = _unpickle(f) 183 | x_train.append(d['data']) 184 | targets_train.append(d['labels']) 185 | x_train = np.array(x_train, dtype='uint8') 186 | shp = x_train.shape 187 | x_train = x_train.reshape(shp[0] * shp[1], 3, 32, 32) 188 | targets_train = np.array(targets_train) 189 | targets_train = targets_train.ravel() 190 | 191 | d = _unpickle(test_file) 192 | x_test = d['data'] 193 | targets_test = d['labels'] 194 | x_test = np.array(x_test, dtype='uint8') 195 | x_test = x_test.reshape(-1, 3, 32, 32) 196 | targets_test = np.array(targets_test) 197 | targets_test = targets_test.ravel() 198 | 199 | if normalized: 200 | x_train = x_train/256.0 201 | x_test = x_test/256.0 202 | if centered: 203 | avg = x_train.mean(axis=0,keepdims=True) 204 | x_train = x_train - avg 205 | x_test = x_test - avg 206 | 207 | if num_val is not None: 208 | perm = np.random.permutation(x_train.shape[0]) 209 | x = x_train[perm] 210 | y = targets_train[perm] 211 | 212 | x_valid = x[:num_val] 213 | targets_valid = y[:num_val] 214 | x_train = x[num_val:] 215 | targets_train = y[num_val:] 216 | return (x_train, targets_train, 217 | x_valid, targets_valid, 218 | x_test, targets_test) 219 | else: 220 | return x_train, targets_train, x_test, targets_test -------------------------------------------------------------------------------- /components/objectives.py: -------------------------------------------------------------------------------- 1 | ''' 2 | objectives 3 | ''' 4 | import numpy as np 5 | import theano.tensor as T 6 | import theano 7 | import lasagne 8 | 9 | from parmesan.distributions import log_stdnormal, log_normal2, log_bernoulli 10 | 11 | def margin_for_reinforce(predictions, num_labelled, delta=1): 12 | num_cls = predictions.shape[1] 13 | predictions=predictions[num_labelled:] # predictions U x nc 14 | p_max = T.max(predictions, axis=1) 15 | p_mean = T.mean(predictions, axis=1) 16 | margin = (p_max - p_mean) / (num_cls - 1) * num_cls 17 | return margin 18 | 19 | def margin_for_reinforce1(predictions, 
num_labelled, delta=1): 20 | num_cls = predictions.shape[1] 21 | predictions=predictions[num_labelled:] # predictions U x nc 22 | p_max = T.max(predictions, axis=1) 23 | p_mean = T.mean(predictions, axis=1) 24 | margin = (p_max - p_mean) / (num_cls - 1) * num_cls 25 | return margin 26 | 27 | def lowerbound_for_reinforce(z, z_mu, z_log_var, x_mu, x, num_features, num_labelled, num_classes, epsilon=1e-6): 28 | x = x.reshape((-1,num_features)) 29 | x_mu = x_mu.reshape((-1,num_features)) 30 | 31 | log_qz_given_xy = log_normal2(z, z_mu, z_log_var).sum(axis=1) 32 | log_pz = log_stdnormal(z).sum(axis=1) 33 | log_py = T.log(1.0/num_classes) 34 | log_px_given_zy = log_bernoulli(x, T.clip(x_mu, epsilon, 1 - epsilon)).sum(axis=1) 35 | ll_xy = log_px_given_zy + log_pz + log_py - log_qz_given_xy 36 | return ll_xy[num_labelled:] 37 | 38 | def multiclass_s3vm_loss(predictions, targets, num_labelled, weight_decay, norm_type=2, form ='mean_class', alpha_hinge=1., alpha_hat=1., alpha_reg=1., alpha_decay=1., delta=1., entropy_term=False): 39 | ''' 40 | predictions: 41 | size L x nc 42 | U x nc 43 | targets: 44 | size L x nc 45 | 46 | output: 47 | weighted sum of hinge loss, hat loss, balance constraint and weight decay 48 | ''' 49 | num_cls = predictions.shape[1] 50 | if targets.ndim == predictions.ndim - 1: 51 | targets = theano.tensor.extra_ops.to_one_hot(targets, num_cls) 52 | elif targets.ndim != predictions.ndim: 53 | raise TypeError('rank mismatch between targets and predictions') 54 | 55 | hinge_loss = multiclass_hinge_loss_(predictions[:num_labelled], targets, delta) 56 | hat_loss = multiclass_hat_loss(predictions[num_labelled:], delta) 57 | regularization = balance_constraint(predictions, targets, num_labelled, norm_type, form) 58 | if not entropy_term: 59 | return alpha_hinge*hinge_loss.mean() + alpha_hat*hat_loss.mean() + alpha_reg*regularization + alpha_decay*weight_decay 60 | else: 61 | # given an unlabeled data, when treat hat loss as the entropy term derived from a lowerbound, it should conflict to current prediction, which is quite strange but true ... the entropy term enforce the discriminator to predict unlabeled data uniformly as a regularization 62 | # max entropy regularization provides a tighter lowerbound but hurt the semi-supervised learning performance as it conflicts to the hat loss ... 
63 | return alpha_hinge*hinge_loss.mean() - alpha_hat*hat_loss.mean() + alpha_reg*regularization + alpha_decay*weight_decay 64 | 65 | def multiclass_hinge_loss_(predictions, targets, delta=1): 66 | return lasagne.objectives.multiclass_hinge_loss(predictions, targets, delta) 67 | 68 | def multiclass_hinge_loss(predictions, targets, weight_decay, alpha_decay=1., delta=1): 69 | return multiclass_hinge_loss_(predictions, targets, delta).mean() + alpha_decay*weight_decay 70 | 71 | def multiclass_hat_loss(predictions, delta=1): 72 | targets = T.argmax(predictions, axis=1) 73 | return multiclass_hinge_loss(predictions, targets, delta) 74 | 75 | def balance_constraint(predictions, targets, num_labelled, norm_type=2, form='mean_class'): 76 | ''' 77 | balance constraint 78 | ------ 79 | norm_type: type of norm 80 | l2 or l1 81 | form: form of regularization 82 | mean_class: average mean activation of u and l data should be the same over each class 83 | mean_all: average mean activation of u and l data should be the same over all data 84 | ratio: 85 | 86 | ''' 87 | p_l = predictions[:num_labelled] 88 | p_u = predictions[num_labelled:] 89 | t_l = targets 90 | t_u = T.argmax(p_u, axis=1) 91 | num_cls = predictions.shape[1] 92 | t_u = theano.tensor.extra_ops.to_one_hot(t_u, num_cls) 93 | if form == 'mean_class': 94 | res = (p_l*t_l).mean(axis=0) - (p_u*t_u).mean(axis=0) 95 | elif form == 'mean_all': 96 | res = p_l.mean(axis=0) - p_u.mean(axis=0) 97 | elif form == 'ratio': 98 | pass 99 | 100 | # res should be a vector with length number_class 101 | return res.norm(norm_type) 102 | 103 | def latent_gaussian_x_gaussian(z, z_mu, z_log_var, x_mu, x_log_var, x, latent_size, num_features, eq_samples, iw_samples, epsilon=1e-6): 104 | # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions 105 | z = z.reshape((-1, eq_samples, iw_samples, latent_size)) 106 | x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features)) 107 | x_log_var = x_log_var.reshape((-1, eq_samples, iw_samples, num_features)) 108 | 109 | # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs 110 | x = x.reshape((-1,num_features)) 111 | x = x.dimshuffle(0, 'x', 'x', 1) # size: (batch_size, eq_samples, iw_samples, num_features) 112 | z_mu = z_mu.dimshuffle(0, 'x', 'x', 1) # size: (batch_size, eq_samples, iw_samples, num_latent) 113 | z_log_var = z_log_var.dimshuffle(0, 'x', 'x', 1) # size: (batch_size, eq_samples, iw_samples, num_latent) 114 | 115 | # calculate LL components, note that the log_xyz() functions return log prob. 
for indepenedent components separately 116 | # so we sum over feature/latent dimensions for multivariate pdfs 117 | log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=3) 118 | log_pz = log_stdnormal(z).sum(axis=3) 119 | #log_px_given_z = log_bernoulli(x, T.clip(x_mu, epsilon, 1 - epsilon)).sum(axis=3) 120 | log_px_given_z = log_normal2(x, x_mu, x_log_var).sum(axis=3) 121 | 122 | #all log_*** should have dimension (batch_size, eq_samples, iw_samples) 123 | # Calculate the LL using log-sum-exp to avoid underflow 124 | a = log_pz + log_px_given_z - log_qz_given_x # size: (batch_size, eq_samples, iw_samples) 125 | a_max = T.max(a, axis=2, keepdims=True) # size: (batch_size, eq_samples, 1) 126 | 127 | LL = T.mean(a_max) + T.mean( T.log( T.mean(T.exp(a-a_max), axis=2) ) ) 128 | 129 | return LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z) 130 | 131 | def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x_mu, x, latent_size, num_features, eq_samples, iw_samples, epsilon=1e-6): 132 | """ 133 | Latent z : gaussian with standard normal prior 134 | decoder output : bernoulli 135 | 136 | When the output is bernoulli then the output from the decoder 137 | should be sigmoid. The sizes of the inputs are 138 | z: (batch_size*eq_samples*iw_samples, num_latent) 139 | z_mu: (batch_size, num_latent) 140 | z_log_var: (batch_size, num_latent) 141 | x_mu: (batch_size*eq_samples*iw_samples, num_features) 142 | x: (batch_size, num_features) 143 | 144 | Reference: Burda et al. 2015 "Importance Weighted Autoencoders" 145 | """ 146 | 147 | # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions 148 | z = z.reshape((-1, eq_samples, iw_samples, latent_size)) 149 | x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features)) 150 | 151 | # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs 152 | x = x.reshape((-1,num_features)) 153 | x = x.dimshuffle(0, 'x', 'x', 1) # size: (batch_size, eq_samples, iw_samples, num_features) 154 | z_mu = z_mu.dimshuffle(0, 'x', 'x', 1) # size: (batch_size, eq_samples, iw_samples, num_latent) 155 | z_log_var = z_log_var.dimshuffle(0, 'x', 'x', 1) # size: (batch_size, eq_samples, iw_samples, num_latent) 156 | 157 | # calculate LL components, note that the log_xyz() functions return log prob. for indepenedent components separately 158 | # so we sum over feature/latent dimensions for multivariate pdfs 159 | log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=3) 160 | log_pz = log_stdnormal(z).sum(axis=3) 161 | log_px_given_z = log_bernoulli(x, T.clip(x_mu, epsilon, 1 - epsilon)).sum(axis=3) 162 | 163 | #all log_*** should have dimension (batch_size, eq_samples, iw_samples) 164 | # Calculate the LL using log-sum-exp to avoid underflow 165 | a = log_pz + log_px_given_z - log_qz_given_x # size: (batch_size, eq_samples, iw_samples) 166 | a_max = T.max(a, axis=2, keepdims=True) # size: (batch_size, eq_samples, 1) 167 | 168 | # LL is calculated using Eq (8) in Burda et al. 169 | # Working from inside out of the calculation below: 170 | # T.exp(a-a_max): (batch_size, eq_samples, iw_samples) 171 | # -> subtract a_max to avoid overflow. a_max is specific for each set of 172 | # importance samples and is broadcasted over the last dimension. 
173 | # 174 | # T.log( T.mean(T.exp(a-a_max), axis=2) ): (batch_size, eq_samples) 175 | # -> This is the log of the sum over the importance weighted samples 176 | # 177 | # The outer T.mean() computes the mean over eq_samples and batch_size 178 | # 179 | # Lastly we add T.mean(a_max) to correct for the log-sum-exp trick 180 | LL = T.mean(a_max) + T.mean( T.log( T.mean(T.exp(a-a_max), axis=2) ) ) 181 | 182 | return LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z) 183 | -------------------------------------------------------------------------------- /cdgm_x2y_xy2z_zy2x_sl.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code implements max-margin conditional deep generative model which incorporates the side information in generative modelling and uses a discriminative classifier to infer the latent labels 3 | for supervised learning 4 | ''' 5 | 6 | import gzip, os, cPickle, time, math, argparse, shutil, sys 7 | 8 | import numpy as np 9 | import theano.tensor as T 10 | import theano 11 | import lasagne 12 | from parmesan.datasets import load_mnist_realval, load_mnist_binarized, load_frey_faces 13 | from datasets import load_cifar10, load_svhn 14 | from parmesan.layers import SampleLayer 15 | 16 | from layers.merge import ConvConcatLayer, MLPConcatLayer 17 | from utils.others import get_nonlin_list, get_pad_list, bernoullisample, build_log_file, printarray_2D, array2file_2D 18 | from components.shortcuts import convlayer, fractionalstridedlayer, unpoolconvlayer, mlplayer 19 | from components.objectives import latent_gaussian_x_gaussian, latent_gaussian_x_bernoulli 20 | from components.objectives import multiclass_s3vm_loss, multiclass_hinge_loss 21 | import utils.paramgraphics as paramgraphics 22 | 23 | ''' 24 | parameters 25 | ''' 26 | # global 27 | theano.config.floatX = 'float32' 28 | filename_script = os.path.basename(os.path.realpath(__file__)) 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-dataset", type=str, default="mnist_real") 31 | parser.add_argument("-outfolder", type=str, default=os.path.join("results-ssl", os.path.splitext(filename_script)[0])) 32 | parser.add_argument("-preprocess", type=str, default="none") 33 | parser.add_argument("-subset_flag", type=str, default ='false') 34 | # architecture 35 | parser.add_argument("-nz", type=int, default=100) 36 | parser.add_argument("-batch_norm_dgm", type=str, default='false') 37 | parser.add_argument("-top_mlp", type=str, default='false') 38 | parser.add_argument("-mlp_size", type=int, default=256) 39 | parser.add_argument("-batch_norm_classifier", type=str, default='false') 40 | # classifier 41 | parser.add_argument("-batch_size", type=int, default=200) 42 | parser.add_argument("-delta", type=float, default=1.) 
43 | parser.add_argument("-alpha_decay", type=float, default=1e-4) 44 | parser.add_argument("-alpha", type=float, default=.1) 45 | parser.add_argument("-norm_type", type=int, default=2) 46 | parser.add_argument("-form", type=str, default='mean_class') 47 | # feature extractor 48 | parser.add_argument("-nlayers_cla", type=int, default=3) 49 | parser.add_argument("-nk_cla", type=str, default='32,64,128') 50 | parser.add_argument("-dk_cla", type=str, default='4,5,3') 51 | parser.add_argument("-pad_cla", type=str, default='valid,valid,valid') 52 | parser.add_argument("-str_cla", type=str, default='2,2,2') 53 | parser.add_argument("-ps_cla", type=str, default='1,1,1') 54 | parser.add_argument("-nonlin_cla", type=str, default='rectify,rectify,rectify') 55 | parser.add_argument("-dr_cla", type=str, default='0,0,0') 56 | # encoder 57 | parser.add_argument("-nlayers_enc", type=int, default=3) 58 | parser.add_argument("-nk_enc", type=str, default='32,64,128') 59 | parser.add_argument("-dk_enc", type=str, default='4,5,3') 60 | parser.add_argument("-pad_enc", type=str, default='valid,valid,valid') 61 | parser.add_argument("-str_enc", type=str, default='2,2,2') 62 | parser.add_argument("-ps_enc", type=str, default='1,1,1') 63 | parser.add_argument("-nonlin_enc", type=str, default='rectify,rectify,rectify') 64 | parser.add_argument("-dr_enc", type=str, default='0,0,0') 65 | # decoder 66 | parser.add_argument("-nlayers_dec", type=int, default=4) 67 | parser.add_argument("-nk_dec", type=str, default='128,64,32,1') 68 | parser.add_argument("-dk_dec", type=str, default='3,5,4,5') 69 | parser.add_argument("-pad_dec", type=str, default='valid,valid,valid,same') 70 | parser.add_argument("-str_dec", type=str, default='2,2,2,1') 71 | parser.add_argument("-up_method", type=str, default='frac_strided,frac_strided,frac_strided,none') 72 | parser.add_argument("-ps_dec", type=str, default='1,1,1,1') 73 | parser.add_argument("-nonlin_dec", type=str, default='rectify,rectify,rectify,sigmoid') 74 | parser.add_argument("-dr_dec", type=str, default='0,0,0,0') 75 | # optimization 76 | parser.add_argument("-flag", type=str, default='validation') # validation for anneal learning rate 77 | parser.add_argument("-lr", type=float, default=0.0003) 78 | parser.add_argument("-nepochs", type=int, default=200) 79 | parser.add_argument("-anneal_lr_epoch", type=int, default=100) 80 | parser.add_argument("-anneal_lr_factor", type=float, default=.99) 81 | parser.add_argument("-every_anneal", type=int, default=1) 82 | clip_grad = 1 83 | max_norm = 5 84 | # name 85 | parser.add_argument("-name", type=str, default='') 86 | # inference 87 | parser.add_argument("-eq_samples", type=int, 88 | help="number of samples for the expectation over q(z|x)", default=1) 89 | parser.add_argument("-iw_samples", type=int, 90 | help="number of importance weighted samples", default=1) 91 | 92 | # random seeds for reproducibility 93 | np.random.seed(1234) 94 | from theano.tensor.shared_randomstreams import RandomStreams 95 | srng = RandomStreams(seed=1234) 96 | 97 | # get parameters 98 | # global 99 | args = parser.parse_args() 100 | dataset = args.dataset 101 | subset_flag = args.subset_flag == 'true' or args.subset_flag == 'True' 102 | eval_epoch = 1 103 | # architecture 104 | nz = args.nz 105 | bn_dgm = args.batch_norm_dgm == 'true' or args.batch_norm_dgm == 'True' 106 | top_mlp = args.top_mlp == 'true' or args.top_mlp == 'True' 107 | mlp_size = args.mlp_size 108 | bn_cla = args.batch_norm_classifier == 'true' or args.batch_norm_classifier == 'True' 109 | # 
classifier 110 | batch_size = args.batch_size 111 | delta = args.delta 112 | alpha_decay = args.alpha_decay 113 | alpha = args.alpha 114 | norm_type = args.norm_type 115 | form = args.form 116 | # feature extractor 117 | nlayers_cla = args.nlayers_cla 118 | nk_cla = map(int, args.nk_cla.split(',')) 119 | dk_cla = map(int, args.dk_cla.split(',')) 120 | pad_cla = map(str, args.pad_cla.split(',')) 121 | str_cla = map(int, args.str_cla.split(',')) 122 | ps_cla = map(int, args.ps_cla.split(',')) 123 | dr_cla = map(float, args.dr_cla.split(',')) 124 | nonlin_cla = get_nonlin_list(map(str, args.nonlin_cla.split(','))) 125 | # encoder 126 | nlayers_enc = args.nlayers_enc 127 | nk_enc = map(int, args.nk_enc.split(',')) 128 | dk_enc = map(int, args.dk_enc.split(',')) 129 | pad_enc = get_pad_list(map(str, args.pad_enc.split(','))) 130 | str_enc = map(int, args.str_enc.split(',')) 131 | ps_enc = map(int, args.ps_enc.split(',')) 132 | dr_enc = map(float, args.dr_enc.split(',')) 133 | nonlin_enc = get_nonlin_list(map(str, args.nonlin_enc.split(','))) 134 | # decoder 135 | nlayers_dec = args.nlayers_dec 136 | nk_dec = map(int, args.nk_dec.split(',')) 137 | dk_dec = map(int, args.dk_dec.split(',')) 138 | pad_dec = get_pad_list(map(str, args.pad_dec.split(','))) 139 | str_dec = map(int, args.str_dec.split(',')) 140 | ps_dec = map(int, args.ps_dec.split(',')) 141 | dr_dec = map(float, args.dr_dec.split(',')) 142 | nonlin_dec = get_nonlin_list(map(str, args.nonlin_dec.split(','))) 143 | up_method = map(str, args.up_method.split(',')) 144 | # optimization 145 | flag = args.flag 146 | lr = args.lr 147 | num_epochs = args.nepochs 148 | anneal_lr_epoch = args.anneal_lr_epoch 149 | anneal_lr_factor = args.anneal_lr_factor 150 | every_anneal = args.every_anneal 151 | # inference 152 | iw_samples = args.iw_samples 153 | eq_samples = args.eq_samples 154 | # log file 155 | logfile, res_out = build_log_file(args, filename_script) 156 | shutil.copy(os.path.realpath(__file__), os.path.join(res_out, filename_script)) 157 | 158 | ''' 159 | datasets 160 | ''' 161 | if dataset == 'mnist_real': 162 | colorImg = False 163 | dim_input = (28,28) 164 | in_channels = 1 165 | num_classes = 10 166 | generation_scale = False 167 | num_generation = 100 168 | vis_epoch = 100 169 | distribution = 'bernoulli' 170 | num_features = in_channels*dim_input[0]*dim_input[1] 171 | print "Using real-valued mnist dataset" 172 | train_x, train_t, valid_x, valid_t, test_x, test_t = load_mnist_realval() 173 | if flag == 'validation': 174 | test_x = valid_x 175 | test_t = valid_t 176 | else: 177 | train_x = np.concatenate([train_x,valid_x]) 178 | train_t = np.hstack((train_t, valid_t)) 179 | train_x_size = train_t.shape[0] 180 | train_t = np.int32(train_t) 181 | test_t = np.int32(test_t) 182 | train_x = train_x.astype(theano.config.floatX) 183 | test_x = test_x.astype(theano.config.floatX) 184 | train_x = train_x.reshape((-1, in_channels)+dim_input) 185 | test_x = test_x.reshape((-1, in_channels)+dim_input) 186 | 187 | elif dataset == 'cifar10': 188 | colorImg = True 189 | dim_input = (32,32) 190 | in_channels = 3 191 | num_classes = 10 192 | generation_scale = False 193 | num_generation = 100 194 | vis_epoch = 100 195 | distribution = 'bernoulli' 196 | num_features = in_channels*dim_input[0]*dim_input[1] 197 | print "Using cifar10 dataset" 198 | train_x, train_t, valid_x, valid_t, test_x, test_t = load_cifar10(num_val=5000, normalized=True, centered=True) 199 | if flag == 'validation': 200 | test_x = valid_x 201 | test_t = valid_t 202 | else: 203 | 
train_x = np.concatenate([train_x,valid_x]) 204 | train_t = np.hstack((train_t, valid_t)) 205 | train_x_size = train_t.shape[0] 206 | train_t = np.int32(train_t) 207 | test_t = np.int32(test_t) 208 | train_x = train_x.astype(theano.config.floatX) 209 | test_x = test_x.astype(theano.config.floatX) 210 | train_x = train_x.reshape((-1, in_channels)+dim_input) 211 | test_x = test_x.reshape((-1, in_channels)+dim_input) 212 | 213 | elif dataset == 'svhn': 214 | colorImg = True 215 | dim_input = (32,32) 216 | in_channels = 3 217 | num_classes = 10 218 | generation_scale = False 219 | num_generation = 100 220 | vis_epoch = 10 221 | distribution = 'bernoulli' 222 | num_features = in_channels*dim_input[0]*dim_input[1] 223 | print "Using svhn dataset" 224 | train_x, train_t, valid_x, valid_t, test_x, test_t = load_svhn(normalized=True, centered=False) 225 | if flag == 'validation': 226 | test_x = valid_x 227 | test_t = valid_t 228 | else: 229 | train_x = np.concatenate([train_x,valid_x]) 230 | train_t = np.hstack((train_t, valid_t)) 231 | train_x_size = train_t.shape[0] 232 | train_t = np.int32(train_t) 233 | test_t = np.int32(test_t) 234 | train_x = train_x.astype(theano.config.floatX) 235 | test_x = test_x.astype(theano.config.floatX) 236 | train_x = train_x.reshape((-1, in_channels)+dim_input) 237 | test_x = test_x.reshape((-1, in_channels)+dim_input) 238 | 239 | # preprocess 240 | 241 | preprocesses_dataset = lambda dataset: dataset 242 | sh_x_train = theano.shared(preprocesses_dataset(train_x), borrow=True) 243 | sh_t_train = theano.shared(train_t, borrow=True) 244 | sh_x_test = theano.shared(preprocesses_dataset(test_x), borrow=True) 245 | sh_t_test = theano.shared(test_t, borrow=True) 246 | 247 | ''' 248 | building block 249 | ''' 250 | # shortcuts 251 | encodelayer = convlayer 252 | 253 | # decoder layer 254 | def decodelayer(l,up_method,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name): 255 | # upsampling 256 | if up_method == 'unpool': 257 | h_g = unpoolconvlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name,'unpool',None) 258 | elif up_method == 'frac_strided': 259 | h_g = fractionalstridedlayer(l,bn,dr,n_kerns,d_kerns,nonlinearity,pad,stride,name) 260 | elif up_method == 'none': 261 | h_g, _ = convlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name) 262 | else: 263 | raise Exception('Unknown upsampling method') 264 | return h_g 265 | 266 | 267 | ''' 268 | model 269 | ''' 270 | # symbolic variables 271 | sym_iw_samples = T.iscalar('iw_samples') 272 | sym_eq_samples = T.iscalar('eq_samples') 273 | sym_lr = T.scalar('lr') 274 | sym_x = T.tensor4('x') 275 | sym_x_cla = T.tensor4('x_cla') 276 | sym_y = T.ivector('y') 277 | sym_index = T.iscalar('index') 278 | sym_batch_size = T.iscalar('batch_size') 279 | batch_slice = slice(sym_index * sym_batch_size, (sym_index + 1) * sym_batch_size) 280 | 281 | # x2y 282 | l_in_x_cla = lasagne.layers.InputLayer((None, in_channels)+dim_input) 283 | l_cla = [l_in_x_cla,] 284 | print lasagne.layers.get_output_shape(l_cla[-1]) 285 | # conv layers 286 | for i in xrange(nlayers_cla): 287 | l, _= convlayer(l_cla[-1],bn_cla,dr_cla[i],ps_cla[i],nk_cla[i],dk_cla[i],nonlin_cla[i],pad_cla[i],str_cla[i],'CLA-'+str(i+1)) 288 | l_cla.append(l) 289 | print lasagne.layers.get_output_shape(l_cla[-1]) 290 | 291 | # feature and classifier 292 | if top_mlp: 293 | l_cla.append(lasagne.layers.FlattenLayer(l_cla[-1])) 294 | feature = mlplayer(l_cla[-1],bn_cla,0.5,mlp_size,lasagne.nonlinearities.rectify,name='MLP-CLA') 295 | else: 296 | feature = 
lasagne.layers.GlobalPoolLayer(l_cla[-1]) 297 | classifier = lasagne.layers.DenseLayer(feature, num_units=num_classes, nonlinearity=lasagne.nonlinearities.identity, W=lasagne.init.Normal(1e-2, 0), name="classifier") 298 | 299 | # encoder xy2z 300 | l_in_x = lasagne.layers.InputLayer((None, in_channels)+dim_input) 301 | l_in_y = lasagne.layers.InputLayer((None,)) 302 | l_enc = [l_in_x,] 303 | for i in xrange(nlayers_enc): 304 | l_enc.append(ConvConcatLayer([l_enc[-1], l_in_y], num_classes)) 305 | l, _ = encodelayer(l_enc[-1],bn_dgm,dr_enc[i],ps_enc[i],nk_enc[i],dk_enc[i],nonlin_enc[i],pad_enc[i],str_enc[i],'ENC-'+str(i+1),False,0) 306 | l_enc.append(l) 307 | print lasagne.layers.get_output_shape(l_enc[-1]) 308 | 309 | # reshape 310 | after_conv_shape = lasagne.layers.get_output_shape(l_enc[-1]) 311 | after_conv_size = int(np.prod(after_conv_shape[1:])) 312 | l_enc.append(lasagne.layers.FlattenLayer(l_enc[-1])) 313 | print lasagne.layers.get_output_shape(l_enc[-1]) 314 | 315 | # compute parameters and sample z 316 | l_mu = mlplayer(l_enc[-1],False,0,nz,lasagne.nonlinearities.identity,'ENC-MU') 317 | l_log_var = mlplayer(l_enc[-1],False,0,nz,lasagne.nonlinearities.identity,'ENC-LOG_VAR') 318 | l_z = SampleLayer(mean=l_mu,log_var=l_log_var,eq_samples=sym_eq_samples,iw_samples=sym_iw_samples) 319 | 320 | # decoder zy2x 321 | l_dec = [l_z,] 322 | print lasagne.layers.get_output_shape(l_dec[-1]) 323 | 324 | # reshape 325 | l_dec.append(mlplayer(l_dec[-1],bn_dgm,0,after_conv_size,lasagne.nonlinearities.rectify, 'DEC_l_Z')) 326 | print lasagne.layers.get_output_shape(l_dec[-1]) 327 | l_dec.append(lasagne.layers.ReshapeLayer(l_dec[-1], shape=(-1,)+after_conv_shape[1:])) 328 | print lasagne.layers.get_output_shape(l_dec[-1]) 329 | for i in (xrange(nlayers_dec-1)): 330 | l_dec.append(ConvConcatLayer([l_dec[-1], l_in_y], num_classes)) 331 | l = decodelayer(l_dec[-1],up_method[i],False,dr_dec[i],ps_dec[i],nk_dec[i],dk_dec[i],nonlin_dec[i],pad_dec[i],str_dec[i],'DEC-'+str(i+1)) 332 | l_dec.append(l) 333 | print lasagne.layers.get_output_shape(l_dec[-1]) 334 | 335 | # mu and var 336 | if distribution == 'gaussian': 337 | l_dec_x_mu = decodelayer(l_dec[-1],up_method[-1],False,dr_dec[-1],ps_dec[-1],nk_dec[-1],dk_dec[-1],lasagne.nonlinearities.identity,pad_dec[-1],str_dec[-1],'DEC-MU') 338 | l_dec_x_log_var = decodelayer(l_dec[-1],up_method[-1],False,dr_dec[-1],ps_dec[-1],nk_dec[-1],dk_dec[-1],lasagne.nonlinearities.identity,pad_dec[-1],str_dec[-1],'DEC-LOG_VAR') 339 | elif distribution == 'bernoulli': 340 | l_dec_x_mu = decodelayer(l_dec[-1],up_method[-1],False,dr_dec[-1],ps_dec[-1],nk_dec[-1],dk_dec[-1],lasagne.nonlinearities.sigmoid,pad_dec[-1],str_dec[-1],'DEC-MU') 341 | print lasagne.layers.get_output_shape(l_dec_x_mu) 342 | 343 | # predictions and accuracies 344 | predictions_train = lasagne.layers.get_output(classifier, sym_x_cla, deterministic=False) 345 | predictions_eval = lasagne.layers.get_output(classifier, sym_x_cla, deterministic=True) 346 | accurracy_train = lasagne.objectives.categorical_accuracy(predictions_train, sym_y) 347 | accurracy_eval = lasagne.objectives.categorical_accuracy(predictions_eval, sym_y) 348 | 349 | # weight decays 350 | weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted({classifier:1}, lasagne.regularization.l2) 351 | 352 | 353 | ''' 354 | learning 355 | ''' 356 | # discriminative objective 357 | classifier_cost_train = multiclass_hinge_loss(predictions=predictions_train, targets=sym_y, weight_decay=weight_decay_classifier, 
alpha_decay=alpha_decay) 358 | classifier_cost_eval = multiclass_hinge_loss(predictions=predictions_eval, targets=sym_y, weight_decay=weight_decay_classifier, alpha_decay=alpha_decay) 359 | 360 | cost_cla = classifier_cost_train 361 | 362 | # generative objective 363 | predictions_train_hard = predictions_train.argmax(axis=1) 364 | predictions_eval_hard = predictions_eval.argmax(axis=1) 365 | 366 | if distribution == 'bernoulli': 367 | z_train, z_mu_train, z_log_var_train, x_mu_train = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu], {l_in_x:sym_x,l_in_y:predictions_train_hard}, deterministic=False) 368 | z_eval, z_mu_eval, z_log_var_eval, x_mu_eval = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu], {l_in_x:sym_x, l_in_y:predictions_eval_hard}, deterministic=True) 369 | 370 | # lower bounds 371 | LL_train, log_qz_given_xy_train, log_pz_train, log_px_given_zy_train = latent_gaussian_x_bernoulli(z_train, z_mu_train, z_log_var_train, x_mu_train, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 372 | LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval = latent_gaussian_x_bernoulli(z_eval, z_mu_eval, z_log_var_eval, x_mu_eval, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 373 | 374 | elif distribution == 'gaussian': 375 | z_train, z_mu_train, z_log_var_train, x_mu_train, x_log_var_train = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu, l_dec_x_log_var], {l_in_x:sym_x, l_in_y:predictions_train_hard}, deterministic=False) 376 | z_eval, z_mu_eval, z_log_var_eval, x_mu_eval, x_log_var_eval = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu, l_dec_x_log_var], {l_in_x:sym_x, l_in_y:predictions_eval_hard}, deterministic=True) 377 | 378 | LL_train, log_qz_given_xy_train, log_pz_train, log_px_given_zy_train = latent_gaussian_x_gaussian(z_train, z_mu_train, z_log_var_train, x_mu_train, x_log_var_train, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 379 | LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval = latent_gaussian_x_gaussian(z_eval, z_mu_eval, z_log_var_eval, x_mu_eval, x_log_var_eval, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 380 | 381 | cost_gen = -LL_train 382 | cost = cost_gen + alpha*cost_cla 383 | 384 | # count parameters 385 | if distribution == 'bernoulli': 386 | params = lasagne.layers.get_all_params([classifier, l_dec_x_mu], trainable=True) 387 | for p in params: 388 | print p, p.get_value().shape 389 | params_count = lasagne.layers.count_params([classifier,l_dec_x_mu], trainable=True) 390 | elif distribution == 'gaussian': 391 | params = lasagne.layers.get_all_params([classifier,l_dec_x_mu, l_dec_x_log_var], trainable=True) 392 | for p in params: 393 | print p, p.get_value().shape 394 | params_count = lasagne.layers.count_params([classifier,l_dec_x_mu, l_dec_x_log_var], trainable=True) 395 | print 'Number of parameters:', params_count 396 | 397 | # functions 398 | grads = T.grad(cost, params) 399 | # mgrads = lasagne.updates.total_norm_constraint(grads,max_norm=max_norm) 400 | # cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads] 401 | updates = lasagne.updates.adam(grads, params, beta1=0.9, beta2=0.999, epsilon=1e-8, learning_rate=sym_lr) 402 | 403 | train_model = theano.function([sym_index, sym_batch_size, sym_lr, sym_eq_samples, sym_iw_samples], [LL_train, 
log_qz_given_xy_train, log_pz_train, log_px_given_zy_train, classifier_cost_train, accurracy_train], givens={sym_x_cla: sh_x_train[batch_slice], sym_x: sh_x_train[batch_slice], sym_y: sh_t_train[batch_slice]}, updates=updates) 404 | test_model = theano.function([sym_index, sym_batch_size, sym_eq_samples, sym_iw_samples], [LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval, classifier_cost_eval, accurracy_eval], givens={sym_x_cla: sh_x_test[batch_slice], sym_x: sh_x_test[batch_slice], sym_y: sh_t_test[batch_slice]}) 405 | 406 | 407 | # random generation for visualization 408 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 409 | srng_ran = RandomStreams(lasagne.random.get_rng().randint(1, 2147462579)) 410 | srng_ran_share = theano.tensor.shared_randomstreams.RandomStreams(1234) 411 | sym_ran_y = T.ivector('ran_y') 412 | 413 | ran_z = T.tile(srng_ran.normal((10,nz)), (10, 1)) 414 | if distribution == 'bernoulli': 415 | random_x_mean = lasagne.layers.get_output(l_dec_x_mu, {l_z:ran_z, l_in_y:sym_ran_y}, deterministic=True) 416 | random_x = srng_ran_share.binomial(n=1, p=random_x_mean, dtype=theano.config.floatX) 417 | elif distribution == 'gaussian': 418 | random_x_mean, random_x_log_var = lasagne.layers.get_output([l_dec_x_mu, l_dec_x_log_var], {l_z:ran_z}, deterministic=True) 419 | random_x = srng_ran_share.normal(avg=random_x_mean, std=T.exp(0.5*random_x_log_var)) 420 | generate = theano.function(inputs=[sym_ran_y], outputs=[random_x_mean, random_x]) 421 | 422 | 423 | ''' 424 | run 425 | ''' 426 | # Training and Testing functions 427 | def train_epoch(lr, eq_samples, iw_samples, batch_size): 428 | costs,log_qz_given_xy,log_pz,log_px_given_zy, loss, accurracy = [],[],[],[],[],[] 429 | n_train_batches = train_x.shape[0] / (batch_size) 430 | 431 | for i in range(n_train_batches): 432 | costs_batch, log_qz_given_xy_batch,log_pz_batch,log_px_given_zy_batch, loss_batch, accurracy_batch = train_model(i, batch_size, lr, eq_samples, iw_samples) 433 | costs += [costs_batch] 434 | log_qz_given_xy += [log_qz_given_xy_batch] 435 | log_pz += [log_pz_batch] 436 | log_px_given_zy += [log_px_given_zy_batch] 437 | loss += [loss_batch] 438 | accurracy += [accurracy_batch] 439 | return np.mean(costs), np.mean(log_qz_given_xy), np.mean(log_pz), np.mean(log_px_given_zy), np.mean(loss), np.mean(accurracy) 440 | 441 | def test_epoch(eq_samples, iw_samples, batch_size): 442 | n_test_batches = test_x.shape[0] / batch_size 443 | costs,log_qz_given_xy,log_pz,log_px_given_zy,loss,accurracy = [],[],[],[],[],[] 444 | for i in range(n_test_batches): 445 | costs_batch, log_qz_given_xy_batch,log_pz_batch,log_px_given_zy_batch, loss_batch, accurracy_batch = test_model(i, batch_size, eq_samples, iw_samples) 446 | costs += [costs_batch] 447 | log_qz_given_xy += [log_qz_given_xy_batch] 448 | log_pz += [log_pz_batch] 449 | log_px_given_zy += [log_px_given_zy_batch] 450 | loss += [loss_batch] 451 | accurracy += [accurracy_batch] 452 | return np.mean(costs), np.mean(log_qz_given_xy), np.mean(log_pz), np.mean(log_px_given_zy), np.mean(loss), np.mean(accurracy) 453 | 454 | 455 | print "Training" 456 | 457 | # TRAIN LOOP 458 | LL_train, log_qz_given_x_train, log_pz_train, log_px_given_z_train, loss_train, acc_train = [],[],[],[],[],[] 459 | LL_test, log_qz_given_x_test, log_pz_test, log_px_given_z_test, loss_test, acc_test = [],[],[],[],[],[] 460 | 461 | for epoch in range(1, 1+num_epochs): 462 | start = time.time() 463 | 464 | # randomly permute data and labels 465 | p = 
np.random.permutation(train_x.shape[0]) 466 | sh_x_train.set_value(preprocesses_dataset(train_x[p])) 467 | sh_t_train.set_value(train_t[p]) 468 | 469 | train_out = train_epoch(lr, eq_samples, iw_samples, batch_size) 470 | 471 | if np.isnan(train_out[0]): 472 | ValueError("NAN in train LL!") 473 | 474 | if epoch >= anneal_lr_epoch and epoch % every_anneal == 0: 475 | #annealing learning rate 476 | lr = lr*anneal_lr_factor 477 | 478 | if epoch % eval_epoch == 0: 479 | t = time.time() - start 480 | LL_train += [train_out[0]] 481 | log_qz_given_x_train += [train_out[1]] 482 | log_pz_train += [train_out[2]] 483 | log_px_given_z_train += [train_out[3]] 484 | loss_train +=[train_out[4]] 485 | acc_train += [train_out[5]] 486 | 487 | print "calculating LL eq=1, iw=1" 488 | test_out = test_epoch(eq_samples, iw_samples, batch_size=500) 489 | LL_test += [test_out[0]] 490 | log_qz_given_x_test += [test_out[1]] 491 | log_pz_test += [test_out[2]] 492 | log_px_given_z_test += [test_out[3]] 493 | loss_test += [test_out[4]] 494 | acc_test += [test_out[5]] 495 | 496 | 497 | line = "*Epoch=%d\tTime=%.2f\tLR=%.5f\n" %(epoch, t, lr) + \ 498 | " TRAIN:\tGen_loss=%.5f\tlogq(z|x)=%.5f\tlogp(z)=%.5f\tlogp(x|z)=%.5f\tdis_loss=%.5f\tlabel_error=%.5f\n" %(LL_train[-1], log_qz_given_x_train[-1], log_pz_train[-1], log_px_given_z_train[-1], loss_train[-1], 1-acc_train[-1]) + \ 499 | " EVAL-L1:\tGen_loss=%.5f\tlogq(z|x)=%.5f\tlogp(z)=%.5f\tlogp(x|z)=%.5f\tdis_loss=%.5f\terror=%.5f\n" %(LL_test[-1], log_qz_given_x_test[-1], log_pz_test[-1], log_px_given_z_test[-1], loss_test[-1], 1-acc_test[-1]) 500 | print line 501 | with open(logfile,'a') as f: 502 | f.write(line + "\n") 503 | 504 | # random generation for visualization 505 | if epoch % vis_epoch == 0: 506 | tail='-'+str(epoch)+'.png' 507 | ran_y = np.int32(np.repeat(np.arange(10), 10)) 508 | _x_mean, _x = generate(ran_y) 509 | _x_mean = _x_mean.reshape((100,-1)) 510 | _x = _x.reshape((100,-1)) 511 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, colorImg=colorImg, scale=generation_scale, 512 | save_path=os.path.join(res_out, 'mean'+tail)) 513 | 514 | #save model 515 | model_out = os.path.join(res_out, 'model') 516 | if epoch % (vis_epoch*10) == 0: 517 | if distribution == 'bernoulli': 518 | all_params=lasagne.layers.get_all_params([classifier, l_dec_x_mu]) 519 | elif distribution == 'gaussian': 520 | all_params=lasagne.layers.get_all_params([classifier, l_dec_x_mu, l_dec_x_log_var]) 521 | f = gzip.open(model_out + 'epoch%i'%(epoch), 'wb') 522 | cPickle.dump(all_params, f, protocol=cPickle.HIGHEST_PROTOCOL) 523 | f.close() -------------------------------------------------------------------------------- /cdgm_x2y_xy2z_zy2x.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code implements max-margin deep conditional generative model which incorporates the side information in generative modelling and uses a semi-supervised classifier to infer the latent labels 3 | ''' 4 | 5 | import gzip, os, cPickle, time, math, argparse, shutil, sys 6 | 7 | import numpy as np 8 | import theano.tensor as T 9 | import theano 10 | import lasagne 11 | from theano.tensor.extra_ops import to_one_hot 12 | from parmesan.datasets import load_mnist_realval, load_mnist_binarized, load_frey_faces, load_norb_small 13 | from datasets import load_cifar10, load_svhn 14 | from datasets_norb import load_numpy_subclasses 15 | from parmesan.layers import SampleLayer 16 | 17 | from layers.merge import ConvConcatLayer, MLPConcatLayer 18 | from 
utils.others import get_nonlin_list, get_pad_list, bernoullisample, build_log_file, printarray_2D, array2file_2D 19 | from components.shortcuts import convlayer, fractionalstridedlayer, unpoolconvlayer, mlplayer 20 | from components.objectives import latent_gaussian_x_gaussian, latent_gaussian_x_bernoulli 21 | from components.objectives import multiclass_s3vm_loss, multiclass_hinge_loss 22 | from utils.create_ssl_data import create_ssl_data, create_ssl_data_subset 23 | import utils.paramgraphics as paramgraphics 24 | 25 | ''' 26 | parameters 27 | ''' 28 | # global 29 | theano.config.floatX = 'float32' 30 | filename_script = os.path.basename(os.path.realpath(__file__)) 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("-dataset", type=str, default="mnist_real") 33 | parser.add_argument("-outfolder", type=str, default=os.path.join("results-ssl", os.path.splitext(filename_script)[0])) 34 | parser.add_argument("-preprocess", type=str, default="none") 35 | parser.add_argument("-subset_flag", type=str, default ='false') 36 | # architecture 37 | parser.add_argument("-nz", type=int, default=100) 38 | parser.add_argument("-batch_norm_dgm", type=str, default='false') 39 | parser.add_argument("-top_mlp", type=str, default='false') 40 | parser.add_argument("-mlp_size", type=int, default=256) 41 | parser.add_argument("-batch_norm_classifier", type=str, default='false') 42 | # classifier 43 | parser.add_argument("-num_labelled", type=int, default=100) 44 | parser.add_argument("-num_labelled_per_batch", type=int, default=100) 45 | parser.add_argument("-batch_size", type=int, default=200) 46 | parser.add_argument("-delta", type=float, default=1.) 47 | parser.add_argument("-alpha_decay", type=float, default=1e-4) 48 | parser.add_argument("-alpha_hinge", type=float, default=1.) 
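# (Illustrative note, based on how these coefficients are used later in this script,
#  not part of the original source: alpha weights the discriminative loss against the
#  generative lower bound in cost = cost_gen + alpha*cost_cla; alpha_decay is passed
#  together with the classifier's L2 weight-decay term to the losses; alpha_hinge,
#  alpha_hat and alpha_reg are forwarded to multiclass_s3vm_loss, judging by their
#  names as weights for its hinge, "hat" (unlabelled) and regularization terms; and
#  alpha_straight_through scales the extra classifier gradient contributed by the
#  straight-through estimator.)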
49 | parser.add_argument("-alpha_hat", type=float, default=.3) 50 | parser.add_argument("-alpha_reg", type=float, default=0) 51 | parser.add_argument("-alpha", type=float, default=.1) 52 | parser.add_argument("-alpha_straight_through", type=float, default=1e-4) 53 | parser.add_argument("-norm_type", type=int, default=2) 54 | parser.add_argument("-form", type=str, default='mean_class') 55 | # feature extractor 56 | parser.add_argument("-nlayers_cla", type=int, default=3) 57 | parser.add_argument("-nk_cla", type=str, default='32,64,128') 58 | parser.add_argument("-dk_cla", type=str, default='4,5,3') 59 | parser.add_argument("-pad_cla", type=str, default='valid,valid,valid') 60 | parser.add_argument("-str_cla", type=str, default='2,2,2') 61 | parser.add_argument("-ps_cla", type=str, default='1,1,1') 62 | parser.add_argument("-nonlin_cla", type=str, default='rectify,rectify,rectify') 63 | parser.add_argument("-dr_cla", type=str, default='0,0,0') 64 | # encoder 65 | parser.add_argument("-nlayers_enc", type=int, default=3) 66 | parser.add_argument("-nk_enc", type=str, default='32,64,128') 67 | parser.add_argument("-dk_enc", type=str, default='4,5,3') 68 | parser.add_argument("-pad_enc", type=str, default='valid,valid,valid') 69 | parser.add_argument("-str_enc", type=str, default='2,2,2') 70 | parser.add_argument("-ps_enc", type=str, default='1,1,1') 71 | parser.add_argument("-nonlin_enc", type=str, default='rectify,rectify,rectify') 72 | parser.add_argument("-dr_enc", type=str, default='0,0,0') 73 | # decoder 74 | parser.add_argument("-nlayers_dec", type=int, default=4) 75 | parser.add_argument("-nk_dec", type=str, default='128,64,32,1') 76 | parser.add_argument("-dk_dec", type=str, default='3,5,4,5') 77 | parser.add_argument("-pad_dec", type=str, default='valid,valid,valid,same') 78 | parser.add_argument("-str_dec", type=str, default='2,2,2,1') 79 | parser.add_argument("-up_method", type=str, default='frac_strided,frac_strided,frac_strided,none') 80 | parser.add_argument("-ps_dec", type=str, default='1,1,1,1') 81 | parser.add_argument("-nonlin_dec", type=str, default='rectify,rectify,rectify,sigmoid') 82 | parser.add_argument("-dr_dec", type=str, default='0,0,0,0') 83 | # optimization 84 | parser.add_argument("-flag", type=str, default='validation') # validation for anneal learning rate 85 | parser.add_argument("-ssl_data_seed", type=int, default=0) # random seed for ssl data generation 86 | parser.add_argument("-lr", type=float, default=0.0003) 87 | parser.add_argument("-nepochs", type=int, default=200) 88 | parser.add_argument("-anneal_lr_epoch", type=int, default=100) 89 | parser.add_argument("-anneal_lr_factor", type=float, default=.99) 90 | parser.add_argument("-every_anneal", type=int, default=1) 91 | clip_grad = 1 92 | max_norm = 5 93 | # name 94 | parser.add_argument("-name", type=str, default='') 95 | # inference 96 | parser.add_argument("-eq_samples", type=int, 97 | help="number of samples for the expectation over q(z|x)", default=1) 98 | parser.add_argument("-iw_samples", type=int, 99 | help="number of importance weighted samples", default=1) 100 | 101 | # random seeds for reproducibility 102 | np.random.seed(1234) 103 | from theano.tensor.shared_randomstreams import RandomStreams 104 | srng = RandomStreams(seed=1234) 105 | 106 | # get parameters 107 | # global 108 | args = parser.parse_args() 109 | dataset = args.dataset 110 | subset_flag = args.subset_flag == 'true' or args.subset_flag == 'True' 111 | eval_epoch = 1 112 | # architecture 113 | nz = args.nz 114 | bn_dgm = 
args.batch_norm_dgm == 'true' or args.batch_norm_dgm == 'True' 115 | top_mlp = args.top_mlp == 'true' or args.top_mlp == 'True' 116 | mlp_size = args.mlp_size 117 | bn_cla = args.batch_norm_classifier == 'true' or args.batch_norm_classifier == 'True' 118 | # classifier 119 | num_labelled = args.num_labelled 120 | batch_size = args.batch_size 121 | num_labelled_per_batch = args.num_labelled_per_batch 122 | assert num_labelled % num_labelled_per_batch == 0 123 | delta = args.delta 124 | alpha_straight_through = args.alpha_straight_through 125 | alpha_decay = args.alpha_decay 126 | alpha_hinge = args.alpha_hinge 127 | alpha_reg = args.alpha_reg 128 | alpha_hat = args.alpha_hat 129 | alpha = args.alpha 130 | norm_type = args.norm_type 131 | form = args.form 132 | # feature extractor 133 | nlayers_cla = args.nlayers_cla 134 | nk_cla = map(int, args.nk_cla.split(',')) 135 | dk_cla = map(int, args.dk_cla.split(',')) 136 | pad_cla = map(str, args.pad_cla.split(',')) 137 | str_cla = map(int, args.str_cla.split(',')) 138 | ps_cla = map(int, args.ps_cla.split(',')) 139 | dr_cla = map(float, args.dr_cla.split(',')) 140 | nonlin_cla = get_nonlin_list(map(str, args.nonlin_cla.split(','))) 141 | # encoder 142 | nlayers_enc = args.nlayers_enc 143 | nk_enc = map(int, args.nk_enc.split(',')) 144 | dk_enc = map(int, args.dk_enc.split(',')) 145 | pad_enc = get_pad_list(map(str, args.pad_enc.split(','))) 146 | str_enc = map(int, args.str_enc.split(',')) 147 | ps_enc = map(int, args.ps_enc.split(',')) 148 | dr_enc = map(float, args.dr_enc.split(',')) 149 | nonlin_enc = get_nonlin_list(map(str, args.nonlin_enc.split(','))) 150 | # decoder 151 | nlayers_dec = args.nlayers_dec 152 | nk_dec = map(int, args.nk_dec.split(',')) 153 | dk_dec = map(int, args.dk_dec.split(',')) 154 | pad_dec = get_pad_list(map(str, args.pad_dec.split(','))) 155 | str_dec = map(int, args.str_dec.split(',')) 156 | ps_dec = map(int, args.ps_dec.split(',')) 157 | dr_dec = map(float, args.dr_dec.split(',')) 158 | nonlin_dec = get_nonlin_list(map(str, args.nonlin_dec.split(','))) 159 | up_method = map(str, args.up_method.split(',')) 160 | # optimization 161 | flag = args.flag 162 | ssl_data_seed = args.ssl_data_seed 163 | if ssl_data_seed == -1: 164 | ssl_data_seed = int(time.time()) 165 | lr = args.lr 166 | num_epochs = args.nepochs 167 | anneal_lr_epoch = args.anneal_lr_epoch 168 | anneal_lr_factor = args.anneal_lr_factor 169 | every_anneal = args.every_anneal 170 | # inference 171 | iw_samples = args.iw_samples 172 | eq_samples = args.eq_samples 173 | # log file 174 | logfile, res_out = build_log_file(args, filename_script, extra=str(args.ssl_data_seed)) 175 | shutil.copy(os.path.realpath(__file__), os.path.join(res_out, filename_script)) 176 | 177 | ''' 178 | datasets 179 | ''' 180 | if dataset == 'mnist_real': 181 | colorImg = False 182 | dim_input = (28,28) 183 | in_channels = 1 184 | num_classes = 10 185 | generation_scale = False 186 | num_generation = num_classes*num_classes 187 | vis_epoch = 100 188 | distribution = 'bernoulli' 189 | num_features = in_channels*dim_input[0]*dim_input[1] 190 | print "Using real-valued mnist dataset" 191 | train_x, train_t, valid_x, valid_t, test_x, test_t = load_mnist_realval() 192 | if flag == 'validation': 193 | test_x = valid_x 194 | test_t = valid_t 195 | else: 196 | train_x = np.concatenate([train_x,valid_x]) 197 | train_t = np.hstack((train_t, valid_t)) 198 | train_x_size = train_t.shape[0] 199 | train_t = np.int32(train_t) 200 | test_t = np.int32(test_t) 201 | train_x = 
train_x.astype(theano.config.floatX) 202 | test_x = test_x.astype(theano.config.floatX) 203 | train_x = train_x.reshape((-1, in_channels)+dim_input) 204 | test_x = test_x.reshape((-1, in_channels)+dim_input) 205 | # prepare data for semi-supervised learning 206 | if subset_flag: 207 | # instead of sampling from 60000 data, sample 100 data for 10 times to make sure that the labelled data with smaller size is a subset of that with larger size. 208 | x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data_subset(train_x, train_t, num_classes, num_labelled, 100, ssl_data_seed) 209 | else: 210 | x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data(train_x, train_t, num_classes, num_labelled, ssl_data_seed) 211 | y_labelled = np.int32(y_labelled) 212 | elif dataset == 'cifar10': 213 | colorImg = True 214 | dim_input = (32,32) 215 | in_channels = 3 216 | num_classes = 10 217 | generation_scale = False 218 | num_generation = num_classes*num_classes 219 | vis_epoch = 100 220 | distribution = 'bernoulli' 221 | num_features = in_channels*dim_input[0]*dim_input[1] 222 | print "Using cifar10 dataset" 223 | train_x, train_t, valid_x, valid_t, test_x, test_t = load_cifar10(num_val=5000, normalized=True, centered=True) 224 | if flag == 'validation': 225 | test_x = valid_x 226 | test_t = valid_t 227 | else: 228 | train_x = np.concatenate([train_x,valid_x]) 229 | train_t = np.hstack((train_t, valid_t)) 230 | train_x_size = train_t.shape[0] 231 | train_t = np.int32(train_t) 232 | test_t = np.int32(test_t) 233 | train_x = train_x.astype(theano.config.floatX) 234 | test_x = test_x.astype(theano.config.floatX) 235 | train_x = train_x.reshape((-1, in_channels)+dim_input) 236 | test_x = test_x.reshape((-1, in_channels)+dim_input) 237 | # prepare data for semi-supervised learning 238 | x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data(train_x, train_t, num_classes, num_labelled, ssl_data_seed) 239 | y_labelled = np.int32(y_labelled) 240 | elif dataset == 'svhn': 241 | colorImg = True 242 | dim_input = (32,32) 243 | in_channels = 3 244 | num_classes = 10 245 | generation_scale = False 246 | num_generation = num_classes*num_classes 247 | vis_epoch = 10 248 | distribution = 'bernoulli' 249 | num_features = in_channels*dim_input[0]*dim_input[1] 250 | print "Using svhn dataset" 251 | train_x, train_t, valid_x, valid_t, test_x, test_t, avg = load_svhn(normalized=True, centered=False) 252 | if flag == 'validation': 253 | test_x = valid_x 254 | test_t = valid_t 255 | else: 256 | train_x = np.concatenate([train_x,valid_x]) 257 | train_t = np.hstack((train_t, valid_t)) 258 | train_x_size = train_t.shape[0] 259 | train_t = np.int32(train_t) 260 | test_t = np.int32(test_t) 261 | train_x = train_x.astype(theano.config.floatX) 262 | test_x = test_x.astype(theano.config.floatX) 263 | train_x = train_x.reshape((-1, in_channels)+dim_input) 264 | test_x = test_x.reshape((-1, in_channels)+dim_input) 265 | # prepare data for semi-supervised learning 266 | x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data(train_x, train_t, num_classes, num_labelled, ssl_data_seed) 267 | y_labelled = np.int32(y_labelled) 268 | elif dataset == 'norb': 269 | colorImg = False 270 | dim_input = (32,32) 271 | in_channels = 1 272 | num_classes = 5 273 | generation_scale = False 274 | num_generation = num_classes*num_classes 275 | vis_epoch = 100 276 | distribution = 'bernoulli' 277 | num_features = in_channels*dim_input[0]*dim_input[1] 278 | print "Using small norb dataset" 279 | x, t = load_numpy_subclasses(size=dim_input[0], 
normalize=True, centered=False) 280 | x = np.transpose(x) 281 | t = t.flatten() 282 | train_x = x[:24300] 283 | test_x = x[24300*2:24300*3] 284 | train_t = t[:24300] 285 | test_t = t[24300*2:24300*3] 286 | if flag == 'validation': 287 | test_x = train_x[:1000] 288 | test_t = train_t[:1000] 289 | train_x = train_x[1000:] 290 | train_t = train_t[1000:] 291 | train_x_size = train_t.shape[0] 292 | train_t = np.int32(train_t) 293 | test_t = np.int32(test_t) 294 | train_x = train_x.astype(theano.config.floatX) 295 | test_x = test_x.astype(theano.config.floatX) 296 | train_x = train_x.reshape((-1, in_channels)+dim_input) 297 | test_x = test_x.reshape((-1, in_channels)+dim_input) 298 | # prepare data for semi-supervised learning 299 | x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data(train_x, train_t, num_classes, num_labelled, ssl_data_seed) 300 | y_labelled = np.int32(y_labelled) 301 | 302 | # preprocess 303 | if args.preprocess == 'none': 304 | preprocesses_dataset = None 305 | elif args.preprocess == 'bernoullisample': 306 | preprocesses_dataset = bernoullisample 307 | elif args.preprocess == 'dequantify': 308 | pass 309 | 310 | # shared variables for semi-supervised learning 311 | sh_x_train_labelled = theano.shared(x_labelled, borrow=True) 312 | sh_x_train_unlabelled = theano.shared(x_unlabelled, borrow=True) 313 | sh_t_train_labelled = theano.shared(y_labelled, borrow=True) 314 | sh_x_test = theano.shared(test_x, borrow=True) 315 | sh_t_test = theano.shared(test_t, borrow=True) 316 | if preprocesses_dataset is not None: 317 | sh_x_train_labelled_preprocessed = theano.shared(preprocesses_dataset(x_labelled), borrow=True) 318 | sh_x_train_unlabelled_preprocessed = theano.shared(preprocesses_dataset(x_unlabelled), borrow=True) 319 | sh_x_test_preprocessed = theano.shared(preprocesses_dataset(test_x), borrow=True) 320 | 321 | # visualize labeled data 322 | if True: 323 | print 'size of training data ', x_labelled.shape, y_labelled.shape, x_unlabelled.shape 324 | _x_mean = x_labelled.reshape((num_labelled,-1)) 325 | _x_mean = _x_mean[:num_generation] 326 | y_order = np.argsort(y_labelled[:num_generation]) 327 | _x_mean = _x_mean[y_order] 328 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, colorImg=colorImg, scale=generation_scale, 329 | save_path=os.path.join(res_out, 'labeled_data'+str(ssl_data_seed)+'.png')) 330 | 331 | ''' 332 | building block 333 | ''' 334 | # shortcuts 335 | encodelayer = convlayer 336 | 337 | # decoder layer 338 | def decodelayer(l,up_method,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name): 339 | # upsampling 340 | if up_method == 'unpool': 341 | h_g = unpoolconvlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name,'unpool',None) 342 | elif up_method == 'frac_strided': 343 | h_g = fractionalstridedlayer(l,bn,dr,n_kerns,d_kerns,nonlinearity,pad,stride,name) 344 | elif up_method == 'none': 345 | h_g, _ = convlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,name) 346 | else: 347 | raise Exception('Unknown upsampling method') 348 | return h_g 349 | 350 | 351 | ''' 352 | model 353 | ''' 354 | # symbolic variables 355 | sym_iw_samples = T.iscalar('iw_samples') 356 | sym_eq_samples = T.iscalar('eq_samples') 357 | sym_lr = T.scalar('lr') 358 | sym_x = T.tensor4('x') 359 | sym_x_cla = T.tensor4('x_cla') 360 | sym_y = T.ivector('y') 361 | sym_index = T.iscalar('index') 362 | sym_batch_size = T.iscalar('batch_size') 363 | batch_slice = slice(sym_index * sym_batch_size, (sym_index + 1) * sym_batch_size) 364 | sym_index_l = T.iscalar('index_l') 
365 | sym_index_u = T.iscalar('index_u') 366 | sym_batch_size_l = T.iscalar('batch_size_l') 367 | sym_batch_size_u = T.iscalar('batch_size_u') 368 | batch_slice_l = slice(sym_index_l * sym_batch_size_l, (sym_index_l + 1) * sym_batch_size_l) 369 | batch_slice_u = slice(sym_index_u * sym_batch_size_u, (sym_index_u + 1) * sym_batch_size_u) 370 | 371 | # x2y 372 | l_in_x_cla = lasagne.layers.InputLayer((None, in_channels)+dim_input) 373 | l_cla = [l_in_x_cla,] 374 | print lasagne.layers.get_output_shape(l_cla[-1]) 375 | # conv layers 376 | for i in xrange(nlayers_cla): 377 | l, _= convlayer(l_cla[-1],bn_cla,dr_cla[i],ps_cla[i],nk_cla[i],dk_cla[i],nonlin_cla[i],pad_cla[i],str_cla[i],'CLA-'+str(i+1)) 378 | l_cla.append(l) 379 | print lasagne.layers.get_output_shape(l_cla[-1]) 380 | 381 | # feature and classifier 382 | if top_mlp: 383 | l_cla.append(lasagne.layers.FlattenLayer(l_cla[-1])) 384 | feature = mlplayer(l_cla[-1],bn_cla,0.5,mlp_size,lasagne.nonlinearities.rectify,name='MLP-CLA') 385 | else: 386 | feature = lasagne.layers.GlobalPoolLayer(l_cla[-1]) 387 | classifier = lasagne.layers.DenseLayer(feature, num_units=num_classes, nonlinearity=lasagne.nonlinearities.identity, W=lasagne.init.Normal(1e-2, 0), name="classifier") 388 | 389 | # encoder xy2z 390 | l_in_x = lasagne.layers.InputLayer((None, in_channels)+dim_input) 391 | l_in_y = lasagne.layers.InputLayer((None,)) 392 | l_enc = [l_in_x,] 393 | for i in xrange(nlayers_enc): 394 | l_enc.append(ConvConcatLayer([l_enc[-1], l_in_y], num_classes)) 395 | l, _ = encodelayer(l_enc[-1],bn_dgm,dr_enc[i],ps_enc[i],nk_enc[i],dk_enc[i],nonlin_enc[i],pad_enc[i],str_enc[i],'ENC-'+str(i+1),False,0) 396 | l_enc.append(l) 397 | print lasagne.layers.get_output_shape(l_enc[-1]) 398 | 399 | # reshape 400 | after_conv_shape = lasagne.layers.get_output_shape(l_enc[-1]) 401 | after_conv_size = int(np.prod(after_conv_shape[1:])) 402 | l_enc.append(lasagne.layers.FlattenLayer(l_enc[-1])) 403 | print lasagne.layers.get_output_shape(l_enc[-1]) 404 | 405 | # compute parameters and sample z 406 | l_mu = mlplayer(l_enc[-1],False,0,nz,lasagne.nonlinearities.identity,'ENC-MU') 407 | l_log_var = mlplayer(l_enc[-1],False,0,nz,lasagne.nonlinearities.identity,'ENC-LOG_VAR') 408 | l_z = SampleLayer(mean=l_mu,log_var=l_log_var,eq_samples=sym_eq_samples,iw_samples=sym_iw_samples) 409 | 410 | # decoder zy2x 411 | l_dec = [l_z,] 412 | print lasagne.layers.get_output_shape(l_dec[-1]) 413 | 414 | # reshape 415 | l_dec.append(mlplayer(l_dec[-1],bn_dgm,0,after_conv_size,lasagne.nonlinearities.rectify, 'DEC_l_Z')) 416 | print lasagne.layers.get_output_shape(l_dec[-1]) 417 | l_dec.append(lasagne.layers.ReshapeLayer(l_dec[-1], shape=(-1,)+after_conv_shape[1:])) 418 | print lasagne.layers.get_output_shape(l_dec[-1]) 419 | for i in (xrange(nlayers_dec-1)): 420 | l_dec.append(ConvConcatLayer([l_dec[-1], l_in_y], num_classes)) 421 | l = decodelayer(l_dec[-1],up_method[i],bn_dgm,dr_dec[i],ps_dec[i],nk_dec[i],dk_dec[i],nonlin_dec[i],pad_dec[i],str_dec[i],'DEC-'+str(i+1)) 422 | l_dec.append(l) 423 | print lasagne.layers.get_output_shape(l_dec[-1]) 424 | 425 | # mu and var 426 | if distribution == 'gaussian': 427 | l_dec_x_mu = decodelayer(l_dec[-1],up_method[-1],bn_dgm,dr_dec[-1],ps_dec[-1],nk_dec[-1],dk_dec[-1],lasagne.nonlinearities.sigmoid,pad_dec[-1],str_dec[-1],'DEC-MU') 428 | l_dec_x_log_var = decodelayer(l_dec[-1],up_method[-1],bn_dgm,dr_dec[-1],ps_dec[-1],nk_dec[-1],dk_dec[-1],lasagne.nonlinearities.identity,pad_dec[-1],str_dec[-1],'DEC-LOG_VAR') 429 | elif distribution == 
'bernoulli': 430 | l_dec_x_mu = decodelayer(l_dec[-1],up_method[-1],bn_dgm,dr_dec[-1],ps_dec[-1],nk_dec[-1],dk_dec[-1],lasagne.nonlinearities.sigmoid,pad_dec[-1],str_dec[-1],'DEC-MU') 431 | print lasagne.layers.get_output_shape(l_dec_x_mu) 432 | 433 | # predictions and accuracies 434 | predictions_train = lasagne.layers.get_output(classifier, sym_x_cla, deterministic=False) 435 | predictions_eval = lasagne.layers.get_output(classifier, sym_x_cla, deterministic=True) 436 | accurracy_train_labeled = lasagne.objectives.categorical_accuracy(predictions_train[:sym_batch_size_l], sym_y) 437 | accurracy_eval = lasagne.objectives.categorical_accuracy(predictions_eval, sym_y) 438 | 439 | # weight decays 440 | weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted({classifier:1}, lasagne.regularization.l2) 441 | 442 | ''' 443 | learning 444 | ''' 445 | # discriminative objective 446 | #classifier_cost_train = multiclass_hinge_loss(predictions=predictions_train[:num_labelled], targets=sym_y[:num_labelled], weight_decay=weight_decay_classifier, alpha_decay=alpha_decay) 447 | classifier_cost_train = multiclass_s3vm_loss(predictions=predictions_train, targets=sym_y, weight_decay=weight_decay_classifier, norm_type=norm_type, form=form, num_labelled=sym_batch_size_l, alpha_decay=alpha_decay, alpha_reg=alpha_reg, alpha_hat=alpha_hat, alpha_hinge=alpha_hinge, delta=delta) 448 | classifier_cost_eval = multiclass_hinge_loss(predictions=predictions_eval, targets=sym_y, weight_decay=weight_decay_classifier, alpha_decay=alpha_decay) # no hat loss for testing 449 | 450 | cost_cla = classifier_cost_train 451 | 452 | # generative objective 453 | predictions_train_hard = predictions_train.argmax(axis=1) 454 | predictions_eval_hard = predictions_eval.argmax(axis=1) 455 | 456 | sym_l_in_y_train = to_one_hot(T.concatenate([sym_y,predictions_train_hard[sym_batch_size_l:]], axis=0), num_classes) 457 | if distribution == 'bernoulli': 458 | z_train, z_mu_train, z_log_var_train, x_mu_train = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu], {l_in_x:sym_x, l_in_y:sym_l_in_y_train}, deterministic=False) 459 | z_eval, z_mu_eval, z_log_var_eval, x_mu_eval = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu], {l_in_x:sym_x, l_in_y:to_one_hot(predictions_eval_hard, num_classes)}, deterministic=True) 460 | 461 | # lower bounds 462 | LL_train, log_qz_given_xy_train, log_pz_train, log_px_given_zy_train = latent_gaussian_x_bernoulli(z_train, z_mu_train, z_log_var_train, x_mu_train, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 463 | LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval = latent_gaussian_x_bernoulli(z_eval, z_mu_eval, z_log_var_eval, x_mu_eval, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 464 | 465 | elif distribution == 'gaussian': 466 | z_train, z_mu_train, z_log_var_train, x_mu_train, x_log_var_train = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu, l_dec_x_log_var], {l_in_x:sym_x, l_in_y:sym_l_in_y_train}, deterministic=False) 467 | z_eval, z_mu_eval, z_log_var_eval, x_mu_eval, x_log_var_eval = lasagne.layers.get_output([l_z, l_mu, l_log_var, l_dec_x_mu, l_dec_x_log_var], {l_in_x:sym_x,l_in_y:to_one_hot(predictions_eval_hard, num_classes)}, deterministic=True) 468 | 469 | LL_train, log_qz_given_xy_train, log_pz_train, log_px_given_zy_train = latent_gaussian_x_gaussian(z_train, z_mu_train, z_log_var_train, x_mu_train, 
x_log_var_train, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 470 | LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval = latent_gaussian_x_gaussian(z_eval, z_mu_eval, z_log_var_eval, x_mu_eval, x_log_var_eval, sym_x, latent_size=nz, num_features=num_features, eq_samples=sym_eq_samples, iw_samples=sym_iw_samples) 471 | 472 | cost_gen = -LL_train 473 | cost = cost_gen + alpha*cost_cla 474 | 475 | # count parameters 476 | if distribution == 'bernoulli': 477 | params = lasagne.layers.get_all_params([classifier, l_dec_x_mu], trainable=True) 478 | for p in params: 479 | print p, p.get_value().shape 480 | params_count = lasagne.layers.count_params([classifier,l_dec_x_mu], trainable=True) 481 | elif distribution == 'gaussian': 482 | params = lasagne.layers.get_all_params([classifier,l_dec_x_mu, l_dec_x_log_var], trainable=True) 483 | for p in params: 484 | print p, p.get_value().shape 485 | params_count = lasagne.layers.count_params([classifier,l_dec_x_mu, l_dec_x_log_var], trainable=True) 486 | print 'Number of parameters:', params_count 487 | 488 | # gradients 489 | grads = T.grad(cost, params) 490 | 491 | ''' 492 | Straight Through Estimator 493 | forward pass: logits -> y=argmax -> f 494 | backward pass: f/y * p/theta 495 | ''' 496 | cla_params = lasagne.layers.get_all_params(classifier, trainable=True) 497 | grad_one_hot_y = T.grad(-LL_train, sym_l_in_y_train) 498 | cla_loss_gen = (grad_one_hot_y*lasagne.nonlinearities.softmax(predictions_train)).sum() 499 | cla_grads_gen = T.grad(cla_loss_gen,cla_params) 500 | 501 | for i in xrange(len(cla_grads_gen)): 502 | grads[i] += alpha_straight_through*cla_grads_gen[i] 503 | 504 | # mgrads = lasagne.updates.total_norm_constraint(grads,max_norm=max_norm) 505 | # cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads] 506 | 507 | # functions 508 | updates = lasagne.updates.adam(grads, params, beta1=0.9, beta2=0.999, epsilon=1e-8, learning_rate=sym_lr) 509 | 510 | if preprocesses_dataset is not None: 511 | train_model = theano.function([sym_index_l, sym_index_u, sym_batch_size_l, sym_batch_size_u, sym_lr, sym_eq_samples, sym_iw_samples], [LL_train, log_qz_given_xy_train, log_pz_train, log_px_given_zy_train, classifier_cost_train, accurracy_train_labeled], givens={sym_x_cla:T.concatenate([sh_x_train_labelled[batch_slice_l],sh_x_train_unlabelled[batch_slice_u]], axis=0), sym_x: T.concatenate([sh_x_train_labelled_preprocessed[batch_slice_l],sh_x_train_unlabelled_preprocessed[batch_slice_u]], axis=0), sym_y:sh_t_train_labelled[batch_slice_l]}, updates=updates) 512 | test_model = theano.function([sym_index, sym_batch_size, sym_eq_samples, sym_iw_samples], [LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval, classifier_cost_eval, accurracy_eval], givens={sym_x_cla: sh_x_test[batch_slice], sym_x: sh_x_test_preprocessed[batch_slice], sym_y: sh_t_test[batch_slice]}) 513 | else: 514 | train_model = theano.function([sym_index_l, sym_index_u, sym_batch_size_l, sym_batch_size_u, sym_lr, sym_eq_samples, sym_iw_samples], [LL_train, log_qz_given_xy_train, log_pz_train, log_px_given_zy_train, classifier_cost_train,accurracy_train_labeled], givens={sym_x_cla:T.concatenate([sh_x_train_labelled[batch_slice_l],sh_x_train_unlabelled[batch_slice_u]], axis=0), sym_x: T.concatenate([sh_x_train_labelled[batch_slice_l],sh_x_train_unlabelled[batch_slice_u]], axis=0), sym_y: sh_t_train_labelled[batch_slice_l]}, updates=updates) 515 | test_model = theano.function([sym_index, sym_batch_size, 
sym_eq_samples, sym_iw_samples], [LL_eval, log_qz_given_xy_eval, log_pz_eval, log_px_given_zy_eval, classifier_cost_eval, accurracy_eval], givens={sym_x_cla: sh_x_test[batch_slice], sym_x: sh_x_test[batch_slice], sym_y: sh_t_test[batch_slice]}) 516 | 517 | # random generation for visualization 518 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 519 | srng_ran = RandomStreams(lasagne.random.get_rng().randint(1, 2147462579)) 520 | srng_ran_share = theano.tensor.shared_randomstreams.RandomStreams(1234) 521 | sym_ran_y = T.ivector('ran_y') 522 | 523 | ran_z = T.tile(srng_ran.normal((num_classes,nz)), (num_classes, 1)) 524 | if distribution == 'bernoulli': 525 | random_x_mean = lasagne.layers.get_output(l_dec_x_mu, {l_z:ran_z, l_in_y:to_one_hot(sym_ran_y, num_classes)}, deterministic=True) 526 | random_x = srng_ran_share.binomial(n=1, p=random_x_mean, dtype=theano.config.floatX) 527 | elif distribution == 'gaussian': 528 | random_x_mean, random_x_log_var = lasagne.layers.get_output([l_dec_x_mu, l_dec_x_log_var], {l_z:ran_z, l_in_y:to_one_hot(sym_ran_y, num_classes)}, deterministic=True) 529 | random_x = srng_ran_share.normal(avg=random_x_mean, std=T.exp(0.5*random_x_log_var)) 530 | generate = theano.function(inputs=[sym_ran_y], outputs=[random_x_mean, random_x]) 531 | 532 | 533 | ''' 534 | run 535 | ''' 536 | # Training and Testing functions 537 | def train_epoch(lr, eq_samples, iw_samples, batch_size): 538 | costs,log_qz_given_xy,log_pz,log_px_given_zy, loss, accurracy, accurracy_labeled = [],[],[],[],[],[],[] 539 | n_train_batches_labelled = x_labelled.shape[0] / num_labelled_per_batch 540 | n_train_batches_unlabelled = x_unlabelled.shape[0] / (batch_size - num_labelled_per_batch) 541 | 542 | for i in range(n_train_batches_unlabelled): 543 | costs_batch, log_qz_given_xy_batch,log_pz_batch,log_px_given_zy_batch, loss_batch, accurracy_labeled_batch = train_model(i % n_train_batches_labelled, i, num_labelled_per_batch, batch_size-num_labelled_per_batch, lr, eq_samples, iw_samples) 544 | costs += [costs_batch] 545 | log_qz_given_xy += [log_qz_given_xy_batch] 546 | log_pz += [log_pz_batch] 547 | log_px_given_zy += [log_px_given_zy_batch] 548 | loss += [loss_batch] 549 | accurracy_labeled += [accurracy_labeled_batch] 550 | return np.mean(costs), np.mean(log_qz_given_xy), np.mean(log_pz), np.mean(log_px_given_zy), np.mean(loss), np.mean(accurracy_labeled) 551 | 552 | def test_epoch(eq_samples, iw_samples, batch_size): 553 | n_test_batches = test_x.shape[0] / batch_size 554 | costs,log_qz_given_xy,log_pz,log_px_given_zy,loss, accurracy = [],[],[],[],[],[] 555 | for i in range(n_test_batches): 556 | costs_batch, log_qz_given_xy_batch,log_pz_batch,log_px_given_zy_batch, loss_batch, accurracy_batch = test_model(i, batch_size, eq_samples, iw_samples) 557 | costs += [costs_batch] 558 | log_qz_given_xy += [log_qz_given_xy_batch] 559 | log_pz += [log_pz_batch] 560 | log_px_given_zy += [log_px_given_zy_batch] 561 | loss += [loss_batch] 562 | accurracy += [accurracy_batch] 563 | return np.mean(costs), np.mean(log_qz_given_xy), np.mean(log_pz), np.mean(log_px_given_zy), np.mean(loss), np.mean(accurracy) 564 | 565 | 566 | print "Training" 567 | 568 | # TRAIN LOOP 569 | LL_train, log_qz_given_x_train, log_pz_train, log_px_given_z_train, loss_train, acc_labeled_train = [],[],[],[],[],[] 570 | LL_test, log_qz_given_x_test, log_pz_test, log_px_given_z_test, loss_test, acc_test = [],[],[],[],[],[] 571 | 572 | for epoch in range(1, 1+num_epochs): 573 | start = time.time() 574 | 575 | # 
randomly permute data and labels 576 | p_l = np.random.permutation(x_labelled.shape[0]) 577 | sh_x_train_labelled.set_value(x_labelled[p_l]) 578 | sh_t_train_labelled.set_value((y_labelled[p_l])) 579 | p_u = np.random.permutation(x_unlabelled.shape[0]) 580 | sh_x_train_unlabelled.set_value(x_unlabelled[p_u]) 581 | if preprocesses_dataset is not None: 582 | sh_x_train_labelled_preprocessed.set_value(preprocesses_dataset(x_labelled[p_l])) 583 | sh_x_train_unlabelled_preprocessed.set_value(preprocesses_dataset(x_unlabelled[p_u])) 584 | 585 | train_out = train_epoch(lr, eq_samples, iw_samples, batch_size) 586 | 587 | if np.isnan(train_out[0]): 588 | ValueError("NAN in train LL!") 589 | 590 | if epoch >= anneal_lr_epoch and epoch % every_anneal == 0: 591 | #annealing learning rate 592 | lr = lr*anneal_lr_factor 593 | 594 | if epoch % eval_epoch == 0: 595 | t = time.time() - start 596 | LL_train += [train_out[0]] 597 | log_qz_given_x_train += [train_out[1]] 598 | log_pz_train += [train_out[2]] 599 | log_px_given_z_train += [train_out[3]] 600 | loss_train +=[train_out[4]] 601 | acc_labeled_train += [train_out[5]] 602 | 603 | print "calculating LL eq=1, iw=1" 604 | test_out = test_epoch(eq_samples, iw_samples, batch_size=500) 605 | LL_test += [test_out[0]] 606 | log_qz_given_x_test += [test_out[1]] 607 | log_pz_test += [test_out[2]] 608 | log_px_given_z_test += [test_out[3]] 609 | loss_test += [test_out[4]] 610 | acc_test += [test_out[5]] 611 | 612 | line = "*Epoch=%d\tTime=%.2f\tLR=%.5f\n" %(epoch, t, lr) + \ 613 | " TRAIN:\tGen_loss=%.5f\tlogq(z|x)=%.5f\tlogp(z)=%.5f\tlogp(x|z)=%.5f\tdis_loss=%.5f\tlabel_error=%.5f\n" %(LL_train[-1], log_qz_given_x_train[-1], log_pz_train[-1], log_px_given_z_train[-1], loss_train[-1], 1-acc_labeled_train[-1]) + \ 614 | " EVAL-L1:\tGen_loss=%.5f\tlogq(z|x)=%.5f\tlogp(z)=%.5f\tlogp(x|z)=%.5f\tdis_loss=%.5f\terror=%.5f\n" %(LL_test[-1], log_qz_given_x_test[-1], log_pz_test[-1], log_px_given_z_test[-1], loss_test[-1], 1-acc_test[-1]) 615 | print line 616 | with open(logfile,'a') as f: 617 | f.write(line + "\n") 618 | 619 | # random generation for visualization 620 | if epoch % vis_epoch == 0: 621 | tail='-'+str(epoch)+'.png' 622 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes)) 623 | _x_mean, _x = generate(ran_y) 624 | _x_mean = _x_mean.reshape((num_generation,-1)) 625 | _x = _x.reshape((num_generation,-1)) 626 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, colorImg=colorImg, scale=generation_scale, 627 | save_path=os.path.join(res_out, 'mean'+tail)) 628 | 629 | #save model 630 | model_out = os.path.join(res_out, 'model') 631 | if epoch % (vis_epoch*10) == 0: 632 | if distribution == 'bernoulli': 633 | all_params=lasagne.layers.get_all_params([classifier, l_dec_x_mu]) 634 | elif distribution == 'gaussian': 635 | all_params=lasagne.layers.get_all_params([classifier, l_dec_x_mu, l_dec_x_log_var]) 636 | f = gzip.open(model_out + 'epoch%i'%(epoch), 'wb') 637 | cPickle.dump(all_params, f, protocol=cPickle.HIGHEST_PROTOCOL) 638 | f.close() --------------------------------------------------------------------------------
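# A minimal NumPy sketch (illustrative only, not a file from this repository) of how
# train_epoch() in cdgm_x2y_xy2z_zy2x.py composes semi-supervised minibatches: each
# batch of size batch_size holds num_labelled_per_batch labelled examples plus
# (batch_size - num_labelled_per_batch) unlabelled ones; the unlabelled pool is swept
# once per epoch while the much smaller labelled pool is cycled with a modulo index.
# All sizes below are hypothetical and chosen only to make the cycling visible.
import numpy as np

num_labelled, num_labelled_per_batch, batch_size = 100, 50, 200   # hypothetical sizes
num_unlabelled = 1000                                             # hypothetical pool size
unlabelled_per_batch = batch_size - num_labelled_per_batch        # = 150

n_batches_labelled = num_labelled // num_labelled_per_batch       # = 2
n_batches_unlabelled = num_unlabelled // unlabelled_per_batch     # = 6 updates per epoch

for i in range(n_batches_unlabelled):
    l = (i % n_batches_labelled) * num_labelled_per_batch         # labelled batch index cycles 0,1,0,1,...
    u = i * unlabelled_per_batch                                  # unlabelled slice advances once per step
    labelled_idx = np.arange(l, l + num_labelled_per_batch)
    unlabelled_idx = np.arange(u, u + unlabelled_per_batch)
    # in the real script these two slices index sh_x_train_labelled and
    # sh_x_train_unlabelled and are concatenated inside train_model's givens
    assert labelled_idx.size + unlabelled_idx.size == batch_size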