├── README.md ├── cifar10_data.py ├── components ├── __init__.py ├── objectives.py └── shortcuts.py ├── generate_given_models.py ├── images ├── class-0.png ├── class-1.png ├── class-2.png ├── class-7.png ├── mnist_fm.png ├── mnist_random.png ├── svhn_data.png ├── svhn_linear.png └── svhn_share_z.png ├── layers ├── __init__.py ├── deconv.py ├── merge.py └── sample.py ├── nn.py ├── svhn_data.py ├── utils ├── __init__.py ├── checkpoints.py ├── create_ssl_data.py ├── others.py └── paramgraphics.py ├── x2y_yz2x_xy2p_ssl_cifar10.py ├── x2y_yz2x_xy2p_ssl_mnist.py ├── x2y_yz2x_xy2p_ssl_svhn.py └── zca_bn.py /README.md: -------------------------------------------------------------------------------- 1 | # Triple Generative Adversarial Nets (Triple-GAN) 2 | ## [Chongxuan Li](https://github.com/zhenxuan00), [Kun Xu](https://github.com/taufikxu), Jun Zhu and Bo Zhang 3 | 4 | Code for reproducing most of the results in the [paper](https://arxiv.org/abs/1703.02291). Triple-GAN is a unified GAN model for classification and class-conditional generation in semi-supervised learning. 5 | 6 | Warning: the code is still under development. 7 | 8 | 9 | ## Triple-GAN-V2 and code in PyTorch! 10 | 11 | We propose Triple-GAN-V2, built upon a mean-teacher classifier and a projection discriminator with spectral normalization, and implement it in PyTorch. See the source code at https://github.com/taufikxu/Triple-GAN 12 | 13 | 14 | ## Environment settings and libraries used in our experiments 15 | 16 | This project was tested under the following environment settings. 17 | - OS: Ubuntu 16.04.3 18 | - GPU: GeForce 1080 Ti or Titan X (Pascal or Maxwell) 19 | - CUDA: 8.0, cuDNN: v5.1 or v7.03 20 | - Python: 2.7.14 (set up with Miniconda2) 21 | - Theano: 0.9.0.dev-c697eeab84e5b8a74908da654b66ec9eca4f1291 22 | - Lasagne: 0.2.dev1 23 | - Parmesan: 0.1.dev1 24 | 25 | > Python 26 | > Numpy 27 | > Scipy 28 | > [Theano](https://github.com/Theano/Theano) 29 | > [Lasagne](https://github.com/Lasagne/Lasagne) (version 0.2.dev1) 30 | > [Parmesan](https://github.com/casperkaae/parmesan) 31 | 32 | We thank the authors of these libraries. We also thank the authors of [Improved-GAN](https://github.com/openai/improved-gan) and [Temporal Ensembling](https://github.com/smlaine2/tempens) for providing their code. Our code is largely adapted from their repositories. 33 | 34 | ## Results 35 | 36 | Triple-GAN achieves excellent classification results on the MNIST, SVHN and CIFAR10 datasets; see the [paper](https://arxiv.org/abs/1703.02291) for a comparison with the previous state of the art.
See generated images as follows: 37 | 38 | ### Comparing Triple-GAN (right) with GAN trained with [feature matching](https://arxiv.org/abs/1606.03498) (left) 39 | 40 | 41 | ### Generating images in four specific classes (airplane, automobile, bird, horse) 42 | 43 | 44 | 45 | ### Disentangling styles from classes (left: data, right: Triple-GAN) 46 | 47 | 48 | ### Class-conditional linear interpolation on latent space 49 | 50 | -------------------------------------------------------------------------------- /cifar10_data.py: -------------------------------------------------------------------------------- 1 | import cPickle 2 | import os 3 | import sys 4 | import tarfile 5 | from six.moves import urllib 6 | import numpy as np 7 | 8 | def maybe_download_and_extract(data_dir, url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'): 9 | if not os.path.exists(os.path.join(data_dir, 'cifar-10-batches-py')): 10 | if not os.path.exists(data_dir): 11 | os.makedirs(data_dir) 12 | filename = url.split('/')[-1] 13 | filepath = os.path.join(data_dir, filename) 14 | if not os.path.exists(filepath): 15 | def _progress(count, block_size, total_size): 16 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 17 | float(count * block_size) / float(total_size) * 100.0)) 18 | sys.stdout.flush() 19 | filepath, _ = urllib.request.urlretrieve(url, filepath, _progress) 20 | print() 21 | statinfo = os.stat(filepath) 22 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 23 | tarfile.open(filepath, 'r:gz').extractall(data_dir) 24 | 25 | def unpickle(file): 26 | fo = open(file, 'rb') 27 | d = cPickle.load(fo) 28 | fo.close() 29 | return {'x': np.cast[np.float32]((-127.5 + d['data'].reshape((10000,3,32,32)))/128.), 'y': np.array(d['labels']).astype(np.int32)} 30 | 31 | def load(data_dir, subset='train'): 32 | maybe_download_and_extract(data_dir) 33 | if subset=='train': 34 | train_data = [unpickle(os.path.join(data_dir,'cifar-10-batches-py/data_batch_' + str(i))) for i in range(1,6)] 35 | trainx = np.concatenate([d['x'] for d in train_data],axis=0) 36 | trainy = np.concatenate([d['y'] for d in train_data],axis=0) 37 | return trainx, trainy 38 | elif subset=='test': 39 | test_data = unpickle(os.path.join(data_dir,'cifar-10-batches-py/test_batch')) 40 | testx = test_data['x'] 41 | testy = test_data['y'] 42 | return testx, testy 43 | else: 44 | raise NotImplementedError('subset should be either train or test') 45 | -------------------------------------------------------------------------------- /components/__init__.py: -------------------------------------------------------------------------------- 1 | from .objectives import * 2 | from .shortcuts import * -------------------------------------------------------------------------------- /components/objectives.py: -------------------------------------------------------------------------------- 1 | ''' 2 | objectives 3 | ''' 4 | import numpy as np 5 | import theano.tensor as T 6 | import theano 7 | import lasagne 8 | from theano.tensor.extra_ops import to_one_hot 9 | 10 | def categorical_crossentropy(predictions, targets, epsilon=1e-6): 11 | # avoid overflow 12 | predictions = T.clip(predictions, epsilon, 1-epsilon) 13 | # check shape of targets 14 | num_cls = predictions.shape[1] 15 | if targets.ndim == predictions.ndim - 1: 16 | targets = theano.tensor.extra_ops.to_one_hot(targets, num_cls) 17 | elif targets.ndim != predictions.ndim: 18 | raise TypeError('rank mismatch between targets and predictions') 19 | return 
lasagne.objectives.categorical_crossentropy(predictions, targets).mean() 20 | 21 | def entropy(predictions): 22 | return categorical_crossentropy(predictions, predictions) 23 | 24 | def negative_entropy_of_mean(predictions): 25 | return -entropy(predictions.mean(axis=0, keepdims=True)) 26 | 27 | def categorical_crossentropy_of_mean(predictions): 28 | num_cls = predictions.shape[1] 29 | uniform_targets = T.ones((1, num_cls)) / num_cls 30 | return categorical_crossentropy(predictions.mean(axis=0, keepdims=True), uniform_targets) 31 | 32 | def categorical_crossentropy_ssl_alternative(predictions, targets, num_labelled, weight_decay, alpha_labeled=1., alpha_unlabeled=.3, alpha_average=1e-3, alpha_decay=1e-4): 33 | ce_loss = categorical_crossentropy(predictions[:num_labelled], targets) 34 | en_loss = entropy(predictions[num_labelled:]) 35 | av_loss = negative_entropy_of_mean(predictions[num_labelled:]) 36 | return alpha_labeled*ce_loss + alpha_unlabeled*en_loss + alpha_average*av_loss + alpha_decay*weight_decay 37 | 38 | def categorical_crossentropy_ssl(predictions, targets, num_labelled, weight_decay, alpha_labeled=1., alpha_unlabeled=.3, alpha_average=1e-3, alpha_decay=1e-4): 39 | ce_loss = categorical_crossentropy(predictions[:num_labelled], targets) 40 | en_loss = entropy(predictions[num_labelled:]) 41 | av_loss = categorical_crossentropy_of_mean(predictions[num_labelled:]) 42 | return alpha_labeled*ce_loss + alpha_unlabeled*en_loss + alpha_average*av_loss + alpha_decay*weight_decay 43 | 44 | def categorical_crossentropy_ssl_separated(predictions_l, targets, predictions_u, weight_decay, alpha_labeled=1., alpha_unlabeled=.3, alpha_average=1e-3, alpha_decay=1e-4): 45 | ce_loss = categorical_crossentropy(predictions_l, targets) 46 | en_loss = entropy(predictions_u) 47 | av_loss = categorical_crossentropy_of_mean(predictions_u) 48 | return alpha_labeled*ce_loss + alpha_unlabeled*en_loss + alpha_average*av_loss + alpha_decay*weight_decay 49 | 50 | def maximum_mean_discripancy(sample, data, sigma=[2. , 5., 10., 20., 40., 80.]): 51 | sample = sample.flatten(2) 52 | data = data.flatten(2) 53 | 54 | x = T.concatenate([sample, data], axis=0) 55 | xx = T.dot(x, x.T) 56 | x2 = T.sum(x*x, axis=1, keepdims=True) 57 | exponent = xx - .5*x2 - .5*x2.T 58 | s_samples = T.ones([sample.shape[0], 1])*1./sample.shape[0] 59 | s_data = -T.ones([data.shape[0], 1])*1./data.shape[0] 60 | s_all = T.concatenate([s_samples, s_data], axis=0) 61 | s_mat = T.dot(s_all, s_all.T) 62 | mmd_loss = 0. 
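# The loop below accumulates a squared-MMD estimate for every bandwidth in `sigma`:
#   exponent[i, j] = x_i . x_j - 0.5*||x_i||^2 - 0.5*||x_j||^2 = -0.5*||x_i - x_j||^2,
#   so T.exp((1./s) * exponent) is an RBF kernel matrix with bandwidth parameter s.
#   s_all carries the signed weights (+1/n for generated samples, -1/m for data), hence
#   T.sum(s_mat * kernel_val) = mean K(sample, sample) - 2*mean K(sample, data) + mean K(data, data),
#   i.e. the (biased) squared-MMD estimate for that kernel; the square root of the sum over
#   all bandwidths is returned.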
63 | for s in sigma: 64 | kernel_val = T.exp((1./s) * exponent) 65 | mmd_loss += T.sum(s_mat*kernel_val) 66 | return T.sqrt(mmd_loss) 67 | 68 | def feature_matching(f_sample, f_data, norm='l2'): 69 | if norm == 'l2': 70 | return T.mean(T.square(T.mean(f_sample,axis=0)-T.mean(f_data,axis=0))) 71 | elif norm == 'l1': 72 | return T.mean(abs(T.mean(f_sample,axis=0)-T.mean(f_data,axis=0))) 73 | else: 74 | raise NotImplementedError 75 | 76 | def multiclass_s3vm_loss(predictions_l, targets, predictions_u, weight_decay, alpha_labeled=1., alpha_unlabeled=1., alpha_average=1., alpha_decay=1., delta=1., norm_type=2, form ='mean_class', entropy_term=False): 77 | ''' 78 | predictions: 79 | size L x nc 80 | U x nc 81 | targets: 82 | size L x nc 83 | 84 | output: 85 | weighted sum of hinge loss, hat loss, balance constraint and weight decay 86 | ''' 87 | num_cls = predictions_l.shape[1] 88 | if targets.ndim == predictions_l.ndim - 1: 89 | targets = theano.tensor.extra_ops.to_one_hot(targets, num_cls) 90 | elif targets.ndim != predictions_l.ndim: 91 | raise TypeError('rank mismatch between targets and predictions') 92 | 93 | hinge_loss = multiclass_hinge_loss_(predictions_l, targets, delta) 94 | hat_loss = multiclass_hat_loss(predictions_u, delta) 95 | regularization = balance_constraint(predictions_l, targets, predictions_u, norm_type, form) 96 | if not entropy_term: 97 | return alpha_labeled*hinge_loss.mean() + alpha_unlabeled*hat_loss.mean() + alpha_average*regularization + alpha_decay*weight_decay 98 | else: 99 | # given an unlabeled data, when treat hat loss as the entropy term derived from a lowerbound, it should conflict to current prediction, which is quite strange but true ... the entropy term enforce the discriminator to predict unlabeled data uniformly as a regularization 100 | # max entropy regularization provides a tighter lowerbound but hurt the semi-supervised learning performance as it conflicts to the hat loss ... 
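# i.e. the objective returned here is alpha_labeled*hinge - alpha_unlabeled*hat + alpha_average*balance + alpha_decay*weight_decay,
# identical to the branch above except that the hat/entropy term enters with a negative weight.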
101 | return alpha_labeled*hinge_loss.mean() - alpha_unlabeled*hat_loss.mean() + alpha_average*regularization + alpha_decay*weight_decay 102 | 103 | 104 | def multiclass_hinge_loss_(predictions, targets, delta=1): 105 | return lasagne.objectives.multiclass_hinge_loss(predictions, targets, delta) 106 | 107 | def multiclass_hinge_loss(predictions, targets, weight_decay, alpha_decay=1., delta=1): 108 | return multiclass_hinge_loss_(predictions, targets, delta).mean() + alpha_decay*weight_decay 109 | 110 | def multiclass_hat_loss(predictions, delta=1): 111 | targets = T.argmax(predictions, axis=1) 112 | return multiclass_hinge_loss_(predictions, targets, delta) # per-sample hinge loss against the predicted labels (weight decay is added by the caller) 113 | 114 | def balance_constraint(p_l, t_l, p_u, norm_type=2, form='mean_class'): 115 | ''' 116 | balance constraint 117 | ------ 118 | norm_type: type of norm 119 | l2 or l1 120 | form: form of regularization 121 | mean_class: the mean activation of unlabelled and labelled data should match for each class 122 | mean_all: the mean activation of unlabelled and labelled data should match over all data 123 | ratio: 124 | 125 | ''' 126 | t_u = T.argmax(p_u, axis=1) 127 | num_cls = p_l.shape[1] 128 | t_u = theano.tensor.extra_ops.to_one_hot(t_u, num_cls) 129 | if form == 'mean_class': 130 | res = (p_l*t_l).mean(axis=0) - (p_u*t_u).mean(axis=0) 131 | elif form == 'mean_all': 132 | res = p_l.mean(axis=0) - p_u.mean(axis=0) 133 | elif form == 'ratio': 134 | pass 135 | 136 | # res should be a vector of length num_cls 137 | return res.norm(norm_type) 138 | -------------------------------------------------------------------------------- /components/shortcuts.py: -------------------------------------------------------------------------------- 1 | ''' 2 | shortcuts for composite layers 3 | ''' 4 | import numpy as np 5 | import theano.tensor as T 6 | import theano 7 | import lasagne 8 | import sys 9 | sys.path.append("..") 10 | from layers.merge import ConvConcatLayer, MLPConcatLayer 11 | 12 | # convolutional layer, 13 | # followed by optional batch normalization, pooling and dropout 14 | def convlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,W=lasagne.init.GlorotUniform(),b=lasagne.init.Constant(0.), name=None): 15 | l = lasagne.layers.Conv2DLayer(l, num_filters=n_kerns, filter_size=d_kerns, stride=stride, pad=pad, W=W, b=b, nonlinearity=nonlinearity, name=name) 16 | if bn: 17 | l = lasagne.layers.batch_norm(l) 18 | if ps > 1: 19 | l = lasagne.layers.MaxPool2DLayer(l, pool_size=(ps,ps)) 20 | if dr > 0: 21 | l = lasagne.layers.DropoutLayer(l, p=dr) 22 | return l 23 | 24 | # MLP layer, 25 | # followed by optional batch normalization and dropout 26 | def mlplayer(l,bn,dr,num_units,nonlinearity,name): 27 | l = lasagne.layers.DenseLayer(l,num_units=num_units,nonlinearity=nonlinearity,name="MLP-"+name) 28 | if bn: 29 | l = lasagne.layers.batch_norm(l, name="BN-"+name) 30 | if dr > 0: 31 | l = lasagne.layers.DropoutLayer(l, p=dr, name="Drop-"+name) 32 | return l 33 | -------------------------------------------------------------------------------- /generate_given_models.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code generates data in various ways given a trained Triple-GAN. 3 | 4 | Note: Due to the effect of Batch Normalization, it is better to generate batch_size_g (see the train file; 200 for cifar10) samples distributed equally across classes in each batch.
5 | ''' 6 | import gzip, os, cPickle, time, math, argparse, shutil, sys 7 | 8 | import numpy as np 9 | import theano, lasagne 10 | import theano.tensor as T 11 | import lasagne.layers as ll 12 | import lasagne.nonlinearities as ln 13 | from lasagne.layers import dnn 14 | import nn 15 | from lasagne.init import Normal 16 | from theano.sandbox.rng_mrg import MRG_RandomStreams 17 | 18 | from layers.merge import ConvConcatLayer, MLPConcatLayer 19 | from layers.deconv import Deconv2DLayer 20 | 21 | from components.shortcuts import convlayer, mlplayer 22 | from components.objectives import categorical_crossentropy_ssl_separated, maximum_mean_discripancy, categorical_crossentropy, feature_matching 23 | from utils.create_ssl_data import create_ssl_data, create_ssl_data_subset 24 | from utils.others import get_nonlin_list, get_pad_list, bernoullisample, printarray_2D, array2file_2D 25 | import utils.paramgraphics as paramgraphics 26 | from utils.checkpoints import load_weights 27 | 28 | # global 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS) 31 | parser.add_argument("-dataset", type=str, default='svhn') 32 | args = parser.parse_args() 33 | 34 | filename_script=os.path.basename(os.path.realpath(__file__)) 35 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0]) 36 | outfolder+='.' 37 | outfolder+=args.dataset 38 | outfolder+='.' 39 | outfolder+=str(int(time.time())) 40 | if not os.path.exists(outfolder): 41 | os.makedirs(outfolder) 42 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script)) 43 | 44 | # seeds 45 | seed=1234 46 | rng=np.random.RandomState(seed) 47 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15)) 48 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15))) 49 | 50 | # G 51 | n_z=100 52 | batch_size_g=200 53 | num_x=50000 54 | # data dependent 55 | if args.dataset == 'svhn' or args.dataset == 'cifar10': 56 | gen_final_non=ln.tanh 57 | num_classes=10 58 | dim_input=(32,32) 59 | in_channels=3 60 | colorImg=True 61 | generation_scale=True 62 | elif args.dataset == 'mnist': 63 | gen_final_non=ln.sigmoid 64 | num_classes=10 65 | dim_input=(28,28) 66 | in_channels=1 67 | colorImg=False 68 | generation_scale=False 69 | 70 | ''' 71 | models 72 | ''' 73 | # symbols 74 | sym_y_g = T.ivector() 75 | sym_z_input = T.matrix() 76 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z)) 77 | sym_z_shared = T.tile(theano_rng.uniform((batch_size_g/num_classes, n_z)), (num_classes, 1)) 78 | 79 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z) 80 | gen_in_z = ll.InputLayer(shape=(None, n_z)) 81 | gen_in_y = ll.InputLayer(shape=(None,)) 82 | gen_layers = [gen_in_z] 83 | if args.dataset == 'svhn' or args.dataset == 'cifar10': 84 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-00')) 85 | gen_layers.append(nn.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=4*4*512, W=Normal(0.05), nonlinearity=nn.relu, name='gen-01'), g=None, name='gen-02')) 86 | gen_layers.append(ll.ReshapeLayer(gen_layers[-1], (-1,512,4,4), name='gen-03')) 87 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-10')) 88 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,256,8,8), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-11'), g=None, name='gen-12')) # 4 -> 8 89 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-20')) 90 | 
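# (Generator architecture for SVHN/CIFAR-10: the one-hot label y is concatenated to the input and to
# every intermediate feature map via MLPConcatLayer/ConvConcatLayer, and stride-2 deconvolutions
# upsample 4x4 -> 8x8 -> 16x16 -> 32x32; the final deconvolution is weight-normalized and uses the
# tanh output nonlinearity stored in gen_final_non.)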
gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,128,16,16), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-21'), g=None, name='gen-22')) # 8 -> 16 91 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-30')) 92 | gen_layers.append(nn.weight_norm(nn.Deconv2DLayer(gen_layers[-1], (None,3,32,32), (5,5), W=Normal(0.05), nonlinearity=gen_final_non, name='gen-31'), train_g=True, init_stdv=0.1, name='gen-32')) # 16 -> 32 93 | elif args.dataset == 'mnist': 94 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-1')) 95 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-2'), name='gen-3')) 96 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-4')) 97 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-5'), name='gen-6')) 98 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-7')) 99 | gen_layers.append(nn.l2normalize(ll.DenseLayer(gen_layers[-1], num_units=28**2, nonlinearity=gen_final_non, name='gen-8'))) 100 | 101 | # outputs 102 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False) 103 | gen_out_x_shared = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_shared}, deterministic=False) 104 | gen_out_x_interpolation = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_input}, deterministic=False) 105 | generate = theano.function(inputs=[sym_y_g], outputs=gen_out_x) 106 | generate_shared = theano.function(inputs=[sym_y_g], outputs=gen_out_x_shared) 107 | generate_interpolation = theano.function(inputs=[sym_y_g, sym_z_input], outputs=gen_out_x_interpolation) 108 | 109 | ''' 110 | Load pretrained model 111 | ''' 112 | load_weights(args.oldmodel, gen_layers) 113 | 114 | # interpolation on latent space (z) class conditionally 115 | for i in xrange(10): 116 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 117 | orignial_z = np.repeat(rng.uniform(size=(num_classes,n_z)), batch_size_g/num_classes, axis=0) 118 | target_z = np.repeat(rng.uniform(size=(num_classes,n_z)), batch_size_g/num_classes, axis=0) 119 | alpha = np.tile(np.arange(batch_size_g/num_classes) * 1.0 / (batch_size_g/num_classes-1), num_classes) 120 | alpha = alpha.reshape(-1,1) 121 | z = np.float32((1-alpha)*orignial_z+alpha*target_z) 122 | x_gen_batch = generate_interpolation(sample_y, z) 123 | x_gen_batch = x_gen_batch.reshape((batch_size_g,-1)) 124 | image = paramgraphics.mat_to_img(x_gen_batch.T, dim_input, colorImg=colorImg, tile_shape=(num_classes, 2*num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'interpolation-'+str(i)+'.png')) 125 | 126 | # class conditionally generation with shared z and fixed y 127 | for i in xrange(10): 128 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 129 | x_gen_batch = generate_shared(sample_y) 130 | x_gen_batch = x_gen_batch.reshape((batch_size_g,-1)) 131 | image = paramgraphics.mat_to_img(x_gen_batch.T, dim_input, colorImg=colorImg, tile_shape=(num_classes, 2*num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'shared-'+str(i)+'.png')) 132 | 133 | # generation with randomly sampled z and y 134 | for i in xrange(10): 135 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 136 | inds = 
np.random.permutation(batch_size_g) 137 | sample_y = sample_y[inds] 138 | x_gen_batch = generate(sample_y) 139 | x_gen_batch = x_gen_batch.reshape((batch_size_g,-1)) 140 | image = paramgraphics.mat_to_img(x_gen_batch.T, dim_input, colorImg=colorImg, tile_shape=(num_classes, 2*num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'random-'+str(i)+'.png')) 141 | 142 | if args.dataset != 'cifar10': 143 | exit() 144 | 145 | # large number of random generation for inception score computation 146 | x_gen = [] 147 | # generation for each class 148 | x_classes = [] 149 | for i in xrange(num_classes): 150 | x_classes.append([]) 151 | for i in xrange(num_x / batch_size_g): 152 | print i 153 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 154 | x_gen_batch = generate(sample_y) 155 | x_gen.append(x_gen_batch) 156 | if i < 5: 157 | for j in xrange(num_classes): 158 | x_classes[j].append(x_gen_batch[j*20:(j+1)*20]) 159 | if i == 5: 160 | for ind in xrange(num_classes): 161 | x_classes[ind] = np.concatenate(x_classes[ind], axis=0) 162 | image = paramgraphics.mat_to_img(x_classes[ind].T, dim_input, colorImg=colorImg, tile_shape=(num_classes, num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'class-'+str(ind)+'.png')) 163 | 164 | x_gen=np.concatenate(x_gen, axis=0) 165 | np.save(os.path.join(outfolder,'inception_score'), x_gen) 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /images/class-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-0.png -------------------------------------------------------------------------------- /images/class-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-1.png -------------------------------------------------------------------------------- /images/class-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-2.png -------------------------------------------------------------------------------- /images/class-7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-7.png -------------------------------------------------------------------------------- /images/mnist_fm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/mnist_fm.png -------------------------------------------------------------------------------- /images/mnist_random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/mnist_random.png -------------------------------------------------------------------------------- /images/svhn_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/svhn_data.png 
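Back to generate_given_models.py above: the class-conditional interpolation loop draws one pair of latent codes per class and walks linearly between them while the label is held fixed. The following is a minimal NumPy-only sketch of that batch construction, assuming num_classes=10, batch_size_g=200 and n_z=100 as in the script (the call to the compiled Theano function generate_interpolation is omitted):

```python
import numpy as np

num_classes, batch_size_g, n_z = 10, 200, 100   # values used in generate_given_models.py
per_class = batch_size_g // num_classes         # 20 interpolation steps per class
rng = np.random.RandomState(1234)

# labels: 20 copies of class 0, then 20 copies of class 1, ...
sample_y = np.repeat(np.arange(num_classes), per_class).astype(np.int32)

# one (start, end) latent pair per class, repeated for every interpolation step
z_start = np.repeat(rng.uniform(size=(num_classes, n_z)), per_class, axis=0)
z_end = np.repeat(rng.uniform(size=(num_classes, n_z)), per_class, axis=0)

# interpolation coefficient ramps from 0 to 1 inside each class block
alpha = np.tile(np.arange(per_class) / float(per_class - 1), num_classes).reshape(-1, 1)

z = ((1.0 - alpha) * z_start + alpha * z_end).astype(np.float32)
print(sample_y.shape, z.shape)   # (200,) and (200, 100)
```

Generating every batch with an equal number of samples per class keeps the batch statistics close to those seen by batch normalization during training, which is what the note in the script's docstring is about.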
-------------------------------------------------------------------------------- /images/svhn_linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/svhn_linear.png -------------------------------------------------------------------------------- /images/svhn_share_z.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/svhn_share_z.png -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .merge import * 2 | from .sample import * 3 | from .deconv import * -------------------------------------------------------------------------------- /layers/deconv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano as th 3 | import theano.tensor as T 4 | import lasagne 5 | from lasagne.layers import dnn 6 | 7 | # code from ImprovedGAN, Salimans and Goodfellow: https://github.com/openai/improved-gan 8 | 9 | class Deconv2DLayer(lasagne.layers.Layer): 10 | def __init__(self, incoming, target_shape, filter_size, stride=(2, 2), 11 | W=lasagne.init.Normal(0.05), b=lasagne.init.Constant(0.), nonlinearity=None, **kwargs): 12 | super(Deconv2DLayer, self).__init__(incoming, **kwargs) 13 | self.target_shape = target_shape 14 | self.nonlinearity = (lasagne.nonlinearities.identity if nonlinearity is None else nonlinearity) 15 | self.filter_size = lasagne.layers.dnn.as_tuple(filter_size, 2) 16 | self.stride = lasagne.layers.dnn.as_tuple(stride, 2) 17 | self.target_shape = target_shape 18 | 19 | self.W_shape = (incoming.output_shape[1], target_shape[1], filter_size[0], filter_size[1]) 20 | self.W = self.add_param(W, self.W_shape, name="W") 21 | if b is not None: 22 | self.b = self.add_param(b, (target_shape[1],), name="b") 23 | else: 24 | self.b = None 25 | 26 | def get_output_for(self, input, **kwargs): 27 | op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=self.target_shape, kshp=self.W_shape, subsample=self.stride, border_mode='half') 28 | activation = op(self.W, input, self.target_shape[2:]) 29 | 30 | if self.b is not None: 31 | activation += self.b.dimshuffle('x', 0, 'x', 'x') 32 | 33 | return self.nonlinearity(activation) 34 | 35 | def get_output_shape_for(self, input_shape): 36 | return self.target_shape -------------------------------------------------------------------------------- /layers/merge.py: -------------------------------------------------------------------------------- 1 | import lasagne 2 | from lasagne import init 3 | from lasagne import nonlinearities 4 | 5 | import theano.tensor as T 6 | import theano 7 | import numpy as np 8 | import theano.tensor.extra_ops as Textra 9 | 10 | __all__ = [ 11 | "ConvConcatLayer", # 12 | "MLPConcatLayer", # 13 | ] 14 | 15 | 16 | class ConvConcatLayer(lasagne.layers.MergeLayer): 17 | ''' 18 | concatenate a tensor and a vector on feature map axis 19 | ''' 20 | def __init__(self, incomings, num_cls, **kwargs): 21 | super(ConvConcatLayer, self).__init__(incomings, **kwargs) 22 | self.num_cls = num_cls 23 | 24 | def get_output_shape_for(self, input_shapes): 25 | res = list(input_shapes[0]) 26 | res[1] += self.num_cls 27 | return tuple(res) 28 | 29 | def get_output_for(self, input, 
**kwargs): 30 | x, y = input 31 | if y.ndim == 1: 32 | y = T.extra_ops.to_one_hot(y, self.num_cls) 33 | if y.ndim == 2: 34 | y = y.dimshuffle(0, 1, 'x', 'x') 35 | assert y.ndim == 4 36 | return T.concatenate([x, y*T.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))], axis=1) 37 | 38 | class MLPConcatLayer(lasagne.layers.MergeLayer): 39 | ''' 40 | concatenate a matrix and a vector on feature axis 41 | ''' 42 | def __init__(self, incomings, num_cls, **kwargs): 43 | super(MLPConcatLayer, self).__init__(incomings, **kwargs) 44 | self.num_cls = num_cls 45 | 46 | def get_output_shape_for(self, input_shapes): 47 | res = list(input_shapes[0]) 48 | res[1] += self.num_cls 49 | return tuple(res) 50 | 51 | def get_output_for(self, input, **kwargs): 52 | x, y = input 53 | if y.ndim == 1: 54 | y = T.extra_ops.to_one_hot(y, self.num_cls) 55 | assert y.ndim == 2 56 | return T.concatenate([x, y], axis=1) -------------------------------------------------------------------------------- /layers/sample.py: -------------------------------------------------------------------------------- 1 | import lasagne 2 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 3 | import theano.tensor as T 4 | import theano 5 | 6 | 7 | class GaussianSampleLayer(lasagne.layers.MergeLayer): 8 | """ 9 | Simple sampling layer drawing a single Monte Carlo sample to approximate 10 | E_q [log( p(x,z) / q(z|x) )]. This is the approach described in [KINGMA]_. 11 | Parameters 12 | ---------- 13 | mu, log_var : :class:`Layer` instances 14 | Parameterizing the mean and log(variance) of the distribution to sample 15 | from as described in [KINGMA]_. The code assumes that these have the 16 | same number of dimensions. 17 | seed : int 18 | seed to random stream 19 | Methods 20 | ---------- 21 | seed : Helper function to change the random seed after init is called 22 | References 23 | ---------- 24 | .. [KINGMA] Kingma, Diederik P., and Max Welling. 25 | "Auto-Encoding Variational Bayes." 26 | arXiv preprint arXiv:1312.6114 (2013). 27 | """ 28 | def __init__(self, mean, log_var, 29 | seed=None, 30 | **kwargs): 31 | super(SimpleSampleLayer, self).__init__([mean, log_var], **kwargs) 32 | 33 | if seed is None: 34 | seed = lasagne.random.get_rng().randint(1, 2147462579) 35 | self._srng = RandomStreams(seed) 36 | 37 | def seed(self, seed=None): 38 | if seed is None: 39 | seed = lasagne.random.get_rng().randint(1, 2147462579) 40 | self._srng.seed(seed) 41 | 42 | def get_output_shape_for(self, input_shapes): 43 | return input_shapes[0] 44 | 45 | def get_output_for(self, input, **kwargs): 46 | mu, log_var = input 47 | eps = self._srng.normal(mu.shape) 48 | z = mu + T.exp(0.5 * log_var) * eps 49 | return z 50 | 51 | class BernoulliSampleLayer(lasagne.layers.Layer): 52 | """ 53 | Simple sampling layer drawing samples from bernoulli distributions. 
54 | Parameters 55 | ---------- 56 | mean : :class:`Layer` instances 57 | Parameterizing the mean value of each bernoulli distribution 58 | seed : int 59 | seed to random stream 60 | Methods 61 | ---------- 62 | seed : Helper function to change the random seed after init is called 63 | """ 64 | 65 | def __init__(self, mean, 66 | seed=None, 67 | **kwargs): 68 | super(SimpleBernoulliSampleLayer, self).__init__(mean, **kwargs) 69 | 70 | if seed is None: 71 | seed = lasagne.random.get_rng().randint(1, 2147462579) 72 | 73 | self._srng = RandomStreams(seed) 74 | 75 | def seed(self, seed=lasagne.random.get_rng().randint(1, 2147462579)): 76 | self._srng.seed(seed) 77 | 78 | def get_output_shape_for(self, input_shape): 79 | return input_shape 80 | 81 | def get_output_for(self, mu, **kwargs): 82 | return self._srng.binomial(size=mu.shape, p=mu, dtype=mu.dtype) -------------------------------------------------------------------------------- /nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | neural network stuff, intended to be used with Lasagne 3 | """ 4 | 5 | import numpy as np 6 | import theano as th 7 | import theano.tensor as T 8 | import lasagne 9 | from lasagne.layers import dnn 10 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 11 | 12 | # T.nnet.relu has some stability issues, this is better 13 | def relu(x): 14 | return T.maximum(x, 0) 15 | 16 | def lrelu(x, a=0.2): 17 | return T.maximum(x, a*x) 18 | 19 | def centered_softplus(x): 20 | return T.nnet.softplus(x) - np.cast[th.config.floatX](np.log(2.)) 21 | 22 | def log_sum_exp(x, axis=1): 23 | m = T.max(x, axis=axis) 24 | return m+T.log(T.sum(T.exp(x-m.dimshuffle(0,'x')), axis=axis)) 25 | 26 | def adam_updates(params, cost, lr=0.001, mom1=0.9, mom2=0.999): 27 | updates = [] 28 | grads = T.grad(cost, params) 29 | t = th.shared(np.cast[th.config.floatX](1.)) 30 | for p, g in zip(params, grads): 31 | value = p.get_value(borrow=True) 32 | v = th.shared(np.zeros(value.shape, dtype=value.dtype), 33 | broadcastable=p.broadcastable) 34 | mg = th.shared(np.zeros(value.shape, dtype=value.dtype), 35 | broadcastable=p.broadcastable) 36 | 37 | v_t = mom1*v + (1. - mom1)*g 38 | mg_t = mom2*mg + (1. - mom2)*T.square(g) 39 | v_hat = v_t / (1. - mom1 ** t) 40 | mg_hat = mg_t / (1. 
- mom2 ** t) 41 | g_t = v_hat / T.sqrt(mg_hat + 1e-8) 42 | p_t = p - lr * g_t 43 | updates.append((v, v_t)) 44 | updates.append((mg, mg_t)) 45 | updates.append((p, p_t)) 46 | updates.append((t, t+1)) 47 | return updates 48 | 49 | class WeightNormLayer(lasagne.layers.Layer): 50 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), 51 | W=lasagne.init.Normal(0.05), train_g=False, init_stdv=1., nonlinearity=relu, **kwargs): 52 | super(WeightNormLayer, self).__init__(incoming, **kwargs) 53 | self.nonlinearity = nonlinearity 54 | self.init_stdv = init_stdv 55 | k = self.input_shape[1] 56 | if b is not None: 57 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 58 | if g is not None: 59 | self.g = self.add_param(g, (k,), name="g", regularizable=False, trainable=train_g) 60 | if len(self.input_shape)==4: 61 | self.axes_to_sum = (0,2,3) 62 | self.dimshuffle_args = ['x',0,'x','x'] 63 | else: 64 | self.axes_to_sum = 0 65 | self.dimshuffle_args = ['x',0] 66 | 67 | # scale weights in layer below 68 | incoming.W_param = incoming.W 69 | #incoming.W_param.set_value(W.sample(incoming.W_param.get_value().shape)) 70 | if incoming.W_param.ndim==4: 71 | if isinstance(incoming, Deconv2DLayer): 72 | W_axes_to_sum = (0,2,3) 73 | W_dimshuffle_args = ['x',0,'x','x'] 74 | else: 75 | W_axes_to_sum = (1,2,3) 76 | W_dimshuffle_args = [0,'x','x','x'] 77 | else: 78 | W_axes_to_sum = 0 79 | W_dimshuffle_args = ['x',0] 80 | if g is not None: 81 | incoming.W = incoming.W_param * (self.g/T.sqrt(1e-6 + T.sum(T.square(incoming.W_param),axis=W_axes_to_sum))).dimshuffle(*W_dimshuffle_args) 82 | else: 83 | incoming.W = incoming.W_param / T.sqrt(1e-6 + T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,keepdims=True)) 84 | 85 | def get_output_for(self, input, init=False, **kwargs): 86 | if init: 87 | m = T.mean(input, self.axes_to_sum) 88 | input -= m.dimshuffle(*self.dimshuffle_args) 89 | inv_stdv = self.init_stdv/T.sqrt(T.mean(T.square(input), self.axes_to_sum)) 90 | input *= inv_stdv.dimshuffle(*self.dimshuffle_args) 91 | self.init_updates = [(self.b, -m*inv_stdv), (self.g, self.g*inv_stdv)] 92 | elif hasattr(self,'b'): 93 | input += self.b.dimshuffle(*self.dimshuffle_args) 94 | 95 | return self.nonlinearity(input) 96 | 97 | def weight_norm(layer, **kwargs): 98 | nonlinearity = getattr(layer, 'nonlinearity', None) 99 | if nonlinearity is not None: 100 | layer.nonlinearity = lasagne.nonlinearities.identity 101 | if hasattr(layer, 'b'): 102 | del layer.params[layer.b] 103 | layer.b = None 104 | return WeightNormLayer(layer, nonlinearity=nonlinearity, **kwargs) 105 | 106 | class Deconv2DLayer(lasagne.layers.Layer): 107 | def __init__(self, incoming, target_shape, filter_size, stride=(2, 2), 108 | W=lasagne.init.Normal(0.05), b=lasagne.init.Constant(0.), nonlinearity=relu, **kwargs): 109 | super(Deconv2DLayer, self).__init__(incoming, **kwargs) 110 | self.target_shape = target_shape 111 | self.nonlinearity = (lasagne.nonlinearities.identity if nonlinearity is None else nonlinearity) 112 | self.filter_size = lasagne.layers.dnn.as_tuple(filter_size, 2) 113 | self.stride = lasagne.layers.dnn.as_tuple(stride, 2) 114 | self.target_shape = target_shape 115 | 116 | self.W_shape = (incoming.output_shape[1], target_shape[1], filter_size[0], filter_size[1]) 117 | self.W = self.add_param(W, self.W_shape, name="W") 118 | if b is not None: 119 | self.b = self.add_param(b, (target_shape[1],), name="b") 120 | else: 121 | self.b = None 122 | 123 | def get_output_for(self, input, **kwargs): 124 | 
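# The transposed ("de")convolution below is implemented as the gradient of a forward convolution
# with respect to its inputs: AbstractConv2d_gradInputs takes the filters and the layer input and
# produces an output of shape target_shape, upsampling spatially by the stride factor.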
op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=self.target_shape, kshp=self.W_shape, subsample=self.stride, border_mode='half') 125 | activation = op(self.W, input, self.target_shape[2:]) 126 | 127 | if self.b is not None: 128 | activation += self.b.dimshuffle('x', 0, 'x', 'x') 129 | 130 | return self.nonlinearity(activation) 131 | 132 | def get_output_shape_for(self, input_shape): 133 | return self.target_shape 134 | 135 | # minibatch discrimination layer 136 | class MinibatchLayer(lasagne.layers.Layer): 137 | def __init__(self, incoming, num_kernels, dim_per_kernel=5, theta=lasagne.init.Normal(0.05), 138 | log_weight_scale=lasagne.init.Constant(0.), b=lasagne.init.Constant(-1.), **kwargs): 139 | super(MinibatchLayer, self).__init__(incoming, **kwargs) 140 | self.num_kernels = num_kernels 141 | num_inputs = int(np.prod(self.input_shape[1:])) 142 | self.theta = self.add_param(theta, (num_inputs, num_kernels, dim_per_kernel), name="theta") 143 | self.log_weight_scale = self.add_param(log_weight_scale, (num_kernels, dim_per_kernel), name="log_weight_scale") 144 | self.W = self.theta * (T.exp(self.log_weight_scale)/T.sqrt(T.sum(T.square(self.theta),axis=0))).dimshuffle('x',0,1) 145 | self.b = self.add_param(b, (num_kernels,), name="b") 146 | 147 | def get_output_shape_for(self, input_shape): 148 | return (input_shape[0], np.prod(input_shape[1:])+self.num_kernels) 149 | 150 | def get_output_for(self, input, init=False, **kwargs): 151 | if input.ndim > 2: 152 | # if the input has more than two dimensions, flatten it into a 153 | # batch of feature vectors. 154 | input = input.flatten(2) 155 | 156 | activation = T.tensordot(input, self.W, [[1], [0]]) 157 | abs_dif = (T.sum(abs(activation.dimshuffle(0,1,2,'x') - activation.dimshuffle('x',1,2,0)),axis=2) 158 | + 1e6 * T.eye(input.shape[0]).dimshuffle(0,'x',1)) 159 | 160 | if init: 161 | mean_min_abs_dif = 0.5 * T.mean(T.min(abs_dif, axis=2),axis=0) 162 | abs_dif /= mean_min_abs_dif.dimshuffle('x',0,'x') 163 | self.init_updates = [(self.log_weight_scale, self.log_weight_scale-T.log(mean_min_abs_dif).dimshuffle(0,'x'))] 164 | 165 | f = T.sum(T.exp(-abs_dif),axis=2) 166 | 167 | if init: 168 | mf = T.mean(f,axis=0) 169 | f -= mf.dimshuffle('x',0) 170 | self.init_updates.append((self.b, -mf)) 171 | else: 172 | f += self.b.dimshuffle('x',0) 173 | 174 | return T.concatenate([input, f], axis=1) 175 | 176 | class BatchNormLayer(lasagne.layers.Layer): 177 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), nonlinearity=relu, **kwargs): 178 | super(BatchNormLayer, self).__init__(incoming, **kwargs) 179 | self.nonlinearity = nonlinearity 180 | k = self.input_shape[1] 181 | if b is not None: 182 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 183 | if g is not None: 184 | self.g = self.add_param(g, (k,), name="g", regularizable=False) 185 | self.avg_batch_mean = self.add_param(lasagne.init.Constant(0.), (k,), name="avg_batch_mean", regularizable=False, trainable=False) 186 | self.avg_batch_var = self.add_param(lasagne.init.Constant(1.), (k,), name="avg_batch_var", regularizable=False, trainable=False) 187 | if len(self.input_shape)==4: 188 | self.axes_to_sum = (0,2,3) 189 | self.dimshuffle_args = ['x',0,'x','x'] 190 | else: 191 | self.axes_to_sum = 0 192 | self.dimshuffle_args = ['x',0] 193 | 194 | def get_output_for(self, input, deterministic=False, **kwargs): 195 | if deterministic: 196 | norm_features = (input-self.avg_batch_mean.dimshuffle(*self.dimshuffle_args)) / T.sqrt(1e-6 + 
self.avg_batch_var).dimshuffle(*self.dimshuffle_args) 197 | else: 198 | batch_mean = T.mean(input,axis=self.axes_to_sum).flatten() 199 | centered_input = input-batch_mean.dimshuffle(*self.dimshuffle_args) 200 | batch_var = T.mean(T.square(centered_input),axis=self.axes_to_sum).flatten() 201 | batch_stdv = T.sqrt(1e-6 + batch_var) 202 | norm_features = centered_input / batch_stdv.dimshuffle(*self.dimshuffle_args) 203 | 204 | # BN updates 205 | new_m = 0.9*self.avg_batch_mean + 0.1*batch_mean 206 | new_v = 0.9*self.avg_batch_var + T.cast((0.1*input.shape[0])/(input.shape[0]-1),th.config.floatX)*batch_var 207 | self.bn_updates = [(self.avg_batch_mean, new_m), (self.avg_batch_var, new_v)] 208 | 209 | if hasattr(self, 'g'): 210 | activation = norm_features*self.g.dimshuffle(*self.dimshuffle_args) 211 | else: 212 | activation = norm_features 213 | if hasattr(self, 'b'): 214 | activation += self.b.dimshuffle(*self.dimshuffle_args) 215 | 216 | return self.nonlinearity(activation) 217 | 218 | def batch_norm(layer, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), **kwargs): 219 | """ 220 | adapted from https://gist.github.com/f0k/f1a6bd3c8585c400c190 221 | """ 222 | nonlinearity = getattr(layer, 'nonlinearity', None) 223 | if nonlinearity is not None: 224 | layer.nonlinearity = lasagne.nonlinearities.identity 225 | else: 226 | nonlinearity = lasagne.nonlinearities.identity 227 | if hasattr(layer, 'b'): 228 | del layer.params[layer.b] 229 | layer.b = None 230 | return BatchNormLayer(layer, b, g, nonlinearity=nonlinearity, **kwargs) 231 | 232 | class GaussianNoiseLayer(lasagne.layers.Layer): 233 | def __init__(self, incoming, sigma=0.1, **kwargs): 234 | super(GaussianNoiseLayer, self).__init__(incoming, **kwargs) 235 | self._srng = RandomStreams(lasagne.random.get_rng().randint(1, 2147462579)) 236 | self.sigma = sigma 237 | 238 | def get_output_for(self, input, deterministic=False, use_last_noise=False, **kwargs): 239 | if deterministic or self.sigma == 0: 240 | return input 241 | else: 242 | if not use_last_noise: 243 | self.noise = self._srng.normal(input.shape, avg=0.0, std=self.sigma) 244 | return input + self.noise 245 | 246 | 247 | # /////////// older code used for MNIST //////////// 248 | 249 | # weight normalization 250 | def l2normalize(layer, train_scale=True): 251 | W_param = layer.W 252 | s = W_param.get_value().shape 253 | if len(s)==4: 254 | axes_to_sum = (1,2,3) 255 | dimshuffle_args = [0,'x','x','x'] 256 | k = s[0] 257 | else: 258 | axes_to_sum = 0 259 | dimshuffle_args = ['x',0] 260 | k = s[1] 261 | layer.W_scale = layer.add_param(lasagne.init.Constant(1.), 262 | (k,), name="W_scale", trainable=train_scale, regularizable=False) 263 | layer.W = W_param * (layer.W_scale/T.sqrt(1e-6 + T.sum(T.square(W_param),axis=axes_to_sum))).dimshuffle(*dimshuffle_args) 264 | return layer 265 | 266 | # fully connected layer with weight normalization 267 | class DenseLayer(lasagne.layers.Layer): 268 | def __init__(self, incoming, num_units, theta=lasagne.init.Normal(0.1), b=lasagne.init.Constant(0.), 269 | weight_scale=lasagne.init.Constant(1.), train_scale=False, nonlinearity=relu, **kwargs): 270 | super(DenseLayer, self).__init__(incoming, **kwargs) 271 | self.nonlinearity = (lasagne.nonlinearities.identity if nonlinearity is None else nonlinearity) 272 | self.num_units = num_units 273 | num_inputs = int(np.prod(self.input_shape[1:])) 274 | self.theta = self.add_param(theta, (num_inputs, num_units), name="theta") 275 | self.weight_scale = self.add_param(weight_scale, (num_units,), 
name="weight_scale", trainable=train_scale) 276 | self.W = self.theta * (self.weight_scale/T.sqrt(T.sum(T.square(self.theta),axis=0))).dimshuffle('x',0) 277 | self.b = self.add_param(b, (num_units,), name="b") 278 | 279 | def get_output_shape_for(self, input_shape): 280 | return (input_shape[0], self.num_units) 281 | 282 | def get_output_for(self, input, init=False, deterministic=False, **kwargs): 283 | if input.ndim > 2: 284 | # if the input has more than two dimensions, flatten it into a 285 | # batch of feature vectors. 286 | input = input.flatten(2) 287 | 288 | activation = T.dot(input, self.W) 289 | 290 | if init: 291 | ma = T.mean(activation, axis=0) 292 | activation -= ma.dimshuffle('x',0) 293 | stdv = T.sqrt(T.mean(T.square(activation),axis=0)) 294 | activation /= stdv.dimshuffle('x',0) 295 | self.init_updates = [(self.weight_scale, self.weight_scale/stdv), (self.b, -ma/stdv)] 296 | else: 297 | activation += self.b.dimshuffle('x', 0) 298 | 299 | return self.nonlinearity(activation) 300 | 301 | from scipy import linalg 302 | class ZCA(object): 303 | def __init__(self, regularization=1e-5, x=None): 304 | self.regularization = regularization 305 | if x is not None: 306 | self.fit(x) 307 | 308 | def fit(self, x): 309 | s = x.shape 310 | x = x.copy().reshape((s[0],np.prod(s[1:]))) 311 | m = np.mean(x, axis=0) 312 | x -= m 313 | sigma = np.dot(x.T,x) / x.shape[0] 314 | U, S, V = linalg.svd(sigma) 315 | tmp = np.dot(U, np.diag(1./np.sqrt(S+self.regularization))) 316 | tmp2 = np.dot(U, np.diag(np.sqrt(S+self.regularization))) 317 | self.ZCA_mat = th.shared(np.dot(tmp, U.T).astype(th.config.floatX)) 318 | self.inv_ZCA_mat = th.shared(np.dot(tmp2, U.T).astype(th.config.floatX)) 319 | self.mean = th.shared(m.astype(th.config.floatX)) 320 | 321 | def apply(self, x): 322 | s = x.shape 323 | if isinstance(x, np.ndarray): 324 | return np.dot(x.reshape((s[0],np.prod(s[1:]))) - self.mean.get_value(), self.ZCA_mat.get_value()).reshape(s) 325 | elif isinstance(x, T.TensorVariable): 326 | return T.dot(x.flatten(2) - self.mean.dimshuffle('x',0), self.ZCA_mat).reshape(s) 327 | else: 328 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables") 329 | 330 | def invert(self, x): 331 | s = x.shape 332 | if isinstance(x, np.ndarray): 333 | return (np.dot(x.reshape((s[0],np.prod(s[1:]))), self.inv_ZCA_mat.get_value()) + self.mean.get_value()).reshape(s) 334 | elif isinstance(x, T.TensorVariable): 335 | return (T.dot(x.flatten(2), self.inv_ZCA_mat) + self.mean.dimshuffle('x',0)).reshape(s) 336 | else: 337 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables") -------------------------------------------------------------------------------- /svhn_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from six.moves import urllib 4 | from scipy.io import loadmat 5 | 6 | def maybe_download(data_dir): 7 | new_data_dir = os.path.join(data_dir, 'svhn') 8 | if not os.path.exists(new_data_dir): 9 | os.makedirs(new_data_dir) 10 | def _progress(count, block_size, total_size): 11 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 12 | sys.stdout.flush() 13 | filepath, _ = urllib.request.urlretrieve('http://ufldl.stanford.edu/housenumbers/train_32x32.mat', new_data_dir+'/train_32x32.mat', _progress) 14 | filepath, _ = urllib.request.urlretrieve('http://ufldl.stanford.edu/housenumbers/test_32x32.mat', 
new_data_dir+'/test_32x32.mat', _progress) 15 | 16 | def load(data_dir, subset='train'): 17 | maybe_download(data_dir) 18 | if subset=='train': 19 | train_data = loadmat(os.path.join(data_dir, 'svhn') + '/train_32x32.mat') 20 | trainx = train_data['X'] 21 | trainy = train_data['y'].flatten() 22 | trainy[trainy==10] = 0 23 | return trainx, trainy 24 | elif subset=='test': 25 | test_data = loadmat(os.path.join(data_dir, 'svhn') + '/test_32x32.mat') 26 | testx = test_data['X'] 27 | testy = test_data['y'].flatten() 28 | testy[testy==10] = 0 29 | return testx, testy 30 | else: 31 | raise NotImplementedError('subset should be either train or test') 32 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .paramgraphics import * 2 | from .others import * 3 | from .create_ssl_data import * -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | import pickle, imp, time 2 | import gzip, logging, operator, os 3 | import os.path as osp 4 | import numpy as np 5 | from path import Path 6 | import coloredlogs 7 | 8 | 9 | def convert2dict(params): 10 | names = [par.name for par in params] 11 | assert len(names) == len(set(names)) 12 | 13 | param_dict = {par.name: par.get_value() for par in params} 14 | return param_dict 15 | 16 | 17 | def save_weights(fname, params, history=None): 18 | param_dict = convert2dict(params) 19 | 20 | logging.info('saving {} parameters to {}'.format(len(params), fname)) 21 | fname = Path(fname) 22 | 23 | filename, ext = osp.splitext(fname) 24 | history_file = osp.join(osp.dirname(fname), 'history.npy') 25 | np.save(history_file, history) 26 | logging.info("Save history to {}".format(history_file)) 27 | if ext == '.npy': 28 | np.save(filename + '.npy', param_dict) 29 | else: 30 | f = gzip.open(fname, 'wb') 31 | pickle.dump(param_dict, f, protocol=pickle.HIGHEST_PROTOCOL) 32 | f.close() 33 | 34 | 35 | def load_dict(fname): 36 | logging.info("Loading weights from {}".format(fname)) 37 | filename, ext = os.path.splitext(fname) 38 | if ext == '.npy': 39 | params_load = np.load(fname).item() 40 | else: 41 | f = gzip.open(fname, 'r') 42 | params_load = pickle.load(f) 43 | f.close() 44 | if type(params_load) is dict: 45 | param_dict = params_load 46 | else: 47 | param_dict = convert2dict(params_load) 48 | return param_dict 49 | 50 | def load_weights_trainable(fname, l_out): 51 | import lasagne 52 | params = lasagne.layers.get_all_params(l_out, trainable=True) 53 | names = [par.name for par in params] 54 | assert len(names) == len(set(names)) 55 | 56 | if type(fname) is list: 57 | param_dict = {} 58 | for name in fname: 59 | t_load = load_dict(name) 60 | param_dict.update(t_load) 61 | else: 62 | param_dict = load_dict(fname) 63 | 64 | for param in params: 65 | if param.name in param_dict: 66 | stored_shape = np.asarray(param_dict[param.name].shape) 67 | param_shape = np.asarray(param.get_value().shape) 68 | if not np.all(stored_shape == param_shape): 69 | warn_msg = 'shape mismatch:' 70 | warn_msg += '{} stored:{} new:{}'.format( 71 | param.name, stored_shape, param_shape) 72 | warn_msg += ', skipping' 73 | logging.warn(warn_msg) 74 | else: 75 | param.set_value(param_dict[param.name]) 76 | else: 77 | logging.warn('unable to load parameter {} from {}: No such variable.' 
78 | .format(param.name, fname)) 79 | 80 | 81 | 82 | def load_weights(fname, l_out): 83 | import lasagne 84 | params = lasagne.layers.get_all_params(l_out) 85 | names = [par.name for par in params] 86 | assert len(names) == len(set(names)) 87 | 88 | if type(fname) is list: 89 | param_dict = {} 90 | for name in fname: 91 | t_load = load_dict(name) 92 | param_dict.update(t_load) 93 | else: 94 | param_dict = load_dict(fname) 95 | assign_weights(params, param_dict) 96 | 97 | def assign_weights(params, param_dict): 98 | for param in params: 99 | if param.name in param_dict: 100 | stored_shape = np.asarray(param_dict[param.name].shape) 101 | param_shape = np.asarray(param.get_value().shape) 102 | if not np.all(stored_shape == param_shape): 103 | warn_msg = 'shape mismatch:' 104 | warn_msg += '{} stored:{} new:{}'.format( 105 | param.name, stored_shape, param_shape) 106 | warn_msg += ', skipping' 107 | logging.warn(warn_msg) 108 | else: 109 | param.set_value(param_dict[param.name]) 110 | else: 111 | logging.warn('Unable to load parameter {}: No such variable.' 112 | .format(param.name)) 113 | 114 | 115 | def get_list_name(obj): 116 | if type(obj) is list: 117 | for i in range(len(obj)): 118 | if callable(obj[i]): 119 | obj[i] = obj[i].__name__ 120 | elif callable(obj): 121 | obj = obj.__name__ 122 | return obj 123 | 124 | 125 | # write commandline parameters to header of logfile 126 | def build_log_file(cfg): 127 | FORMAT="%(asctime)s;%(levelname)s|%(message)s" 128 | DATEF="%H-%M-%S" 129 | logging.basicConfig(formatter=FORMAT, level=logging.DEBUG) 130 | logger = logging.getLogger() 131 | logger.setLevel(logging.DEBUG) 132 | 133 | fh = logging.FileHandler(filename=os.path.join(cfg['outfolder'], 'logfile'+time.strftime("%m-%d")+'.log')) 134 | fh.setLevel(logging.DEBUG) 135 | formatter = logging.Formatter("%(asctime)s;%(levelname)s|%(message)s", "%H:%M:%S") 136 | fh.setFormatter(formatter) 137 | logger.addHandler(fh) 138 | 139 | LEVEL_STYLES = dict( 140 | debug=dict(color='magenta'), 141 | info=dict(color='green'), 142 | verbose=dict(), 143 | warning=dict(color='blue'), 144 | error=dict(color='yellow'), 145 | critical=dict(color='red',bold=True)) 146 | coloredlogs.install(level=logging.DEBUG, fmt=FORMAT, datefmt=DATEF, level_styles=LEVEL_STYLES) 147 | 148 | 149 | args_dict = cfg 150 | sorted_args = sorted(args_dict.items(), key=operator.itemgetter(0)) 151 | logging.info('######################################################') 152 | logging.info('# --Configurable Parameters In this Model--') 153 | for name, val in sorted_args: 154 | logging.info("# " + name + ":\t" + str(get_list_name(val))) 155 | logging.info('######################################################') 156 | 157 | 158 | def get_cfg(args): 159 | if args.cfg is not None: 160 | cfg = imp.load_source('config', args.cfg) 161 | else: 162 | raise Exception("The file path of config_file cannot be ignored") 163 | 164 | getmodel = cfg.get_model 165 | cfg = cfg.cfg 166 | args = vars(args).items() 167 | for name, val in args: 168 | cfg[name] = val 169 | 170 | cfg['outfolder'] = os.path.join(cfg['outfolder'], cfg['name']) 171 | res_out = cfg['outfolder'] 172 | if 'key_point' in cfg: 173 | res_out += '.'+ cfg['key_point'] 174 | if cfg['key_point'] in cfg: 175 | res_out += '-' + str(cfg[cfg['key_point']]) 176 | if 'notime' not in cfg or cfg['notime'] in [False, 'False', 'false', None, 'none', 'None']: 177 | res_out += '.' 
+ time.strftime("%b-%d--%H-%M") 178 | 179 | res_out = os.path.realpath(res_out) 180 | if os.path.exists(res_out): 181 | tcount = 1 182 | while os.path.exists(res_out+'+'+str(tcount)): 183 | tcount += 1 184 | res_out += '+' + str(tcount) 185 | 186 | # print res_out 187 | os.makedirs(res_out) 188 | cfg['outfolder'] = res_out 189 | return cfg, getmodel 190 | -------------------------------------------------------------------------------- /utils/create_ssl_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Create semi-supervised datasets for different models 3 | ''' 4 | import numpy as np 5 | 6 | def create_ssl_data(x, y, n_classes, n_labelled, seed): 7 | # 'x': data matrix, nxk 8 | # 'y': label vector, n 9 | # 'n_classes': number of classes 10 | # 'n_labelled': number of labelled data 11 | # 'seed': random seed 12 | 13 | # check input 14 | if n_labelled%n_classes != 0: 15 | print n_labelled 16 | print n_classes 17 | raise("n_labelled (wished number of labelled samples) not divisible by n_classes (number of classes)") 18 | n_labels_per_class = n_labelled/n_classes 19 | 20 | rng = np.random.RandomState(seed) 21 | index = rng.permutation(x.shape[0]) 22 | x = x[index] 23 | y = y[index] 24 | 25 | # select first several data per class 26 | data_labelled = [0]*n_classes 27 | index_labelled = [] 28 | index_unlabelled = [] 29 | for i in xrange(x.shape[0]): 30 | if data_labelled[y[i]] < n_labels_per_class: 31 | data_labelled[y[i]] += 1 32 | index_labelled.append(i) 33 | else: 34 | index_unlabelled.append(i) 35 | 36 | x_labelled = x[index_labelled] 37 | y_labelled = y[index_labelled] 38 | x_unlabelled = x[index_unlabelled] 39 | y_unlabelled = y[index_unlabelled] 40 | return x_labelled, y_labelled, x_unlabelled, y_unlabelled 41 | 42 | 43 | def create_ssl_data_subset(x, y, n_classes, n_labelled, n_labelled_per_time, seed): 44 | assert n_labelled%n_labelled_per_time==0 45 | times = n_labelled/n_labelled_per_time 46 | x_labelled, y_labelled, x_unlabelled, y_unlabelled = create_ssl_data(x, y, n_classes, n_labelled_per_time, seed) 47 | while (times > 1): 48 | x_labelled_new, y_labelled_new, x_unlabelled, y_unlabelled = create_ssl_data(x_unlabelled, y_unlabelled, n_classes, n_labelled_per_time, seed) 49 | x_labelled = np.vstack((x_labelled, x_labelled_new)) 50 | y_labelled = np.hstack((y_labelled, y_labelled_new)) 51 | times -= 1 52 | return x_labelled, y_labelled, x_unlabelled, y_unlabelled -------------------------------------------------------------------------------- /utils/others.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano.tensor as T 3 | import theano, lasagne 4 | 5 | 6 | def get_pad(pad): 7 | if pad not in ['same', 'valid', 'full']: 8 | pad = tuple(map(int, pad.split('-'))) 9 | return pad 10 | 11 | def get_pad_list(pad_list): 12 | re_list = [] 13 | for p in pad_list: 14 | re_list.append(get_pad(p)) 15 | return re_list 16 | 17 | # nonlinearities 18 | def get_nonlin(nonlin): 19 | if nonlin == 'rectify': 20 | return lasagne.nonlinearities.rectify 21 | elif nonlin == 'leaky_rectify': 22 | return lasagne.nonlinearities.LeakyRectify(0.1) 23 | elif nonlin == 'tanh': 24 | return lasagne.nonlinearities.tanh 25 | elif nonlin == 'sigmoid': 26 | return lasagne.nonlinearities.sigmoid 27 | elif nonlin == 'maxout': 28 | return 'maxout' 29 | elif nonlin == 'none': 30 | return lasagne.nonlinearities.identity 31 | else: 32 | raise ValueError('invalid non-linearity \'' + nonlin + '\'') 33 | 
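# Example mappings:
#   get_pad('same') -> 'same', get_pad('1-1') -> (1, 1)
#   get_nonlin('leaky_rectify') -> lasagne.nonlinearities.LeakyRectify(0.1)
#   get_nonlin_list(['rectify', 'none']) -> [rectify, identity] (defined below)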
34 | def get_nonlin_list(nonlin_list): 35 | re_list = [] 36 | for n in nonlin_list: 37 | re_list.append(get_nonlin(n)) 38 | return re_list 39 | 40 | def bernoullisample(x): 41 | return np.random.binomial(1,x,size=x.shape).astype(theano.config.floatX) 42 | 43 | def array2file_2D(array,logfile): 44 | assert len(array.shape) == 2, array.shape 45 | with open(logfile,'a') as f: 46 | for i in xrange(array.shape[0]): 47 | for j in xrange(array.shape[1]): 48 | f.write(str(array[i][j])+' ') 49 | f.write('\n') 50 | 51 | def printarray_2D(array, precise=2): 52 | assert len(array.shape) == 2, array.shape 53 | format = '%.'+str(precise)+'f' 54 | for i in xrange(array.shape[0]): 55 | for j in xrange(array.shape[1]): 56 | print format %array[i][j], 57 | print -------------------------------------------------------------------------------- /utils/paramgraphics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from scipy.misc import imsave 4 | 5 | def scale_max_min(images, max_p, min_p): 6 | # scale the images according to the max and min 7 | # images f x n, column major 8 | ret = np.zeros(images.shape) 9 | for i in xrange(images.shape[1]): 10 | # clips at first 11 | tmp = np.clip(images[:,i], min_p[i], max_p[i]) 12 | # scale 13 | ret[:,i] = (tmp - min_p[i]) / (max_p[i] - min_p[i]) 14 | 15 | return ret 16 | 17 | def scale_to_unit_interval(ndar, eps=1e-8): 18 | """ Scales all values in the ndarray ndar to be between 0 and 1 """ 19 | ndar = ndar.copy() 20 | ndar -= ndar.min() 21 | ndar *= 1.0 / (ndar.max() + eps) 22 | return ndar 23 | 24 | def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0), 25 | scale=True, 26 | output_pixel_vals=True, 27 | colorImg=False): 28 | """ 29 | Transform an array with one flattened image per row, into an array in 30 | which images are reshaped and layed out like tiles on a floor. 31 | 32 | This function is useful for visualizing datasets whose rows are images, 33 | and also columns of matrices for transforming those rows 34 | (such as the first layer of a neural net). 35 | 36 | :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can 37 | be 2-D ndarrays or None; 38 | :param X: a 2-D array in which every row is a flattened image. 39 | 40 | :type img_shape: tuple; (height, width) 41 | :param img_shape: the original shape of each image 42 | 43 | :type tile_shape: tuple; (rows, cols) 44 | :param tile_shape: the number of images to tile (rows, cols) 45 | 46 | :param output_pixel_vals: if output should be pixel values (i.e. int8 47 | values) or floats 48 | 49 | :param scale_rows_to_unit_interval: if the values need to be scaled before 50 | being plotted to [0,1] or not 51 | 52 | 53 | :returns: array suitable for viewing as an image. 
54 | """ 55 | X = X * 1.0 # converts ints to floats 56 | 57 | if colorImg: 58 | channelSize = X.shape[1]/3 59 | X = (X[:,0:channelSize], X[:,channelSize:2*channelSize], X[:,2*channelSize:3*channelSize], None) 60 | 61 | assert len(img_shape) == 2 62 | assert len(tile_shape) == 2 63 | assert len(tile_spacing) == 2 64 | 65 | # The expression below can be re-written in a more C style as 66 | # follows : 67 | # 68 | # out_shape = [0,0] 69 | # out_shape[0] = (img_shape[0] + tile_spacing[0]) * tile_shape[0] - 70 | # tile_spacing[0] 71 | # out_shape[1] = (img_shape[1] + tile_spacing[1]) * tile_shape[1] - 72 | # tile_spacing[1] 73 | out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp 74 | in zip(img_shape, tile_shape, tile_spacing)] 75 | 76 | if isinstance(X, tuple): 77 | assert len(X) == 4 78 | # Create an output np ndarray to store the image 79 | if output_pixel_vals: 80 | out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype='uint8') 81 | else: 82 | out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype) 83 | 84 | #colors default to 0, alpha defaults to 1 (opaque) 85 | if output_pixel_vals: 86 | channel_defaults = [0, 0, 0, 255] 87 | else: 88 | channel_defaults = [0., 0., 0., 1.] 89 | 90 | 91 | for i in xrange(4): 92 | if X[i] is None: 93 | # if channel is None, fill it with zeros of the correct 94 | # dtype 95 | out_array[:, :, i] = np.zeros(out_shape, 96 | dtype='uint8' if output_pixel_vals else out_array.dtype 97 | ) + channel_defaults[i] 98 | else: 99 | # use a recurrent call to compute the channel and store it 100 | # in the output 101 | xi = X[i] 102 | if scale: 103 | xi = (X[i] - X[i].min()) / (X[i].max() - X[i].min()) 104 | out_array[:, :, i] = tile_raster_images(xi, img_shape, tile_shape, tile_spacing, False, output_pixel_vals) 105 | 106 | 107 | return out_array 108 | 109 | else: 110 | # if we are dealing with only one channel 111 | H, W = img_shape 112 | Hs, Ws = tile_spacing 113 | 114 | # generate a matrix to store the output 115 | out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype) 116 | 117 | 118 | for tile_row in xrange(tile_shape[0]): 119 | for tile_col in xrange(tile_shape[1]): 120 | if tile_row * tile_shape[1] + tile_col < X.shape[0]: 121 | if scale: 122 | # if we should scale values to be between 0 and 1 123 | # do this by calling the `scale_to_unit_interval` 124 | # function 125 | tmp = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape) 126 | this_img = scale_to_unit_interval(tmp) 127 | else: 128 | this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape) 129 | # add the slice to the corresponding position in the 130 | # output array 131 | out_array[ 132 | tile_row * (H+Hs): tile_row * (H + Hs) + H, 133 | tile_col * (W+Ws): tile_col * (W + Ws) + W 134 | ] \ 135 | = this_img * (255 if output_pixel_vals else 1) 136 | return out_array 137 | 138 | # Matrix to image 139 | def mat_to_img(w, dim_input, scale=False, colorImg=False, tile_spacing=(1,1), tile_shape=0, save_path=None): 140 | if tile_shape == 0: 141 | rowscols = int(w.shape[1]**0.5) 142 | tile_shape = (rowscols,rowscols) 143 | w = w[:, 0:rowscols*rowscols] 144 | imgs = tile_raster_images(X=w.T, img_shape=dim_input, tile_shape=tile_shape, tile_spacing=tile_spacing, scale=scale, colorImg=colorImg) 145 | if save_path is not None: 146 | imsave(save_path, imgs) 147 | return imgs -------------------------------------------------------------------------------- /x2y_yz2x_xy2p_ssl_cifar10.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | This code implements a triple GAN for semi-supervised learning on CIFAR10 3 | ''' 4 | import os, sys, time 5 | import numpy as np 6 | 7 | import scipy 8 | from collections import OrderedDict 9 | import pickle 10 | 11 | from lasagne.layers import InputLayer, ReshapeLayer, FlattenLayer, Upscale2DLayer, MaxPool2DLayer, DropoutLayer, ConcatLayer, DenseLayer, NINLayer 12 | from lasagne.layers import GaussianNoiseLayer, Conv2DLayer, Pool2DLayer, GlobalPoolLayer, NonlinearityLayer, FeaturePoolLayer, DimshuffleLayer, ElemwiseSumLayer 13 | from lasagne.utils import floatX 14 | from zca_bn import ZCA 15 | from zca_bn import mean_only_bn as WN 16 | 17 | import gzip, os, cPickle, time, math, argparse, shutil, sys 18 | 19 | import numpy as np 20 | import theano, lasagne 21 | import theano.tensor as T 22 | import lasagne.layers as ll 23 | import lasagne.nonlinearities as ln 24 | from lasagne.layers import dnn 25 | import nn 26 | from lasagne.init import Normal 27 | from theano.sandbox.rng_mrg import MRG_RandomStreams 28 | import cifar10_data 29 | 30 | from layers.merge import ConvConcatLayer, MLPConcatLayer 31 | from layers.deconv import Deconv2DLayer 32 | 33 | from components.shortcuts import convlayer, mlplayer 34 | from components.objectives import categorical_crossentropy_ssl_separated, maximum_mean_discripancy, categorical_crossentropy, feature_matching 35 | from utils.create_ssl_data import create_ssl_data, create_ssl_data_subset 36 | from utils.others import get_nonlin_list, get_pad_list, bernoullisample, printarray_2D, array2file_2D 37 | import utils.paramgraphics as paramgraphics 38 | 39 | def build_network(): 40 | conv_defs = { 41 | 'W': lasagne.init.HeNormal('relu'), 42 | 'b': lasagne.init.Constant(0.0), 43 | 'filter_size': (3, 3), 44 | 'stride': (1, 1), 45 | 'nonlinearity': lasagne.nonlinearities.LeakyRectify(0.1) 46 | } 47 | 48 | nin_defs = { 49 | 'W': lasagne.init.HeNormal('relu'), 50 | 'b': lasagne.init.Constant(0.0), 51 | 'nonlinearity': lasagne.nonlinearities.LeakyRectify(0.1) 52 | } 53 | 54 | dense_defs = { 55 | 'W': lasagne.init.HeNormal(1.0), 56 | 'b': lasagne.init.Constant(0.0), 57 | 'nonlinearity': lasagne.nonlinearities.softmax 58 | } 59 | 60 | wn_defs = { 61 | 'momentum': .999 62 | } 63 | 64 | net = InputLayer ( name='input', shape=(None, 3, 32, 32)) 65 | net = GaussianNoiseLayer(net, name='noise', sigma=.15) 66 | net = WN(Conv2DLayer (net, name='conv1a', num_filters=128, pad='same', **conv_defs), **wn_defs) 67 | net = WN(Conv2DLayer (net, name='conv1b', num_filters=128, pad='same', **conv_defs), **wn_defs) 68 | net = WN(Conv2DLayer (net, name='conv1c', num_filters=128, pad='same', **conv_defs), **wn_defs) 69 | net = MaxPool2DLayer (net, name='pool1', pool_size=(2, 2)) 70 | net = DropoutLayer (net, name='drop1', p=.5) 71 | net = WN(Conv2DLayer (net, name='conv2a', num_filters=256, pad='same', **conv_defs), **wn_defs) 72 | net = WN(Conv2DLayer (net, name='conv2b', num_filters=256, pad='same', **conv_defs), **wn_defs) 73 | net = WN(Conv2DLayer (net, name='conv2c', num_filters=256, pad='same', **conv_defs), **wn_defs) 74 | net = MaxPool2DLayer (net, name='pool2', pool_size=(2, 2)) 75 | net = DropoutLayer (net, name='drop2', p=.5) 76 | net = WN(Conv2DLayer (net, name='conv3a', num_filters=512, pad=0, **conv_defs), **wn_defs) 77 | net = WN(NINLayer (net, name='conv3b', num_units=256, **nin_defs), **wn_defs) 78 | net = WN(NINLayer (net, name='conv3c', num_units=128, **nin_defs), **wn_defs) 79 | 
net = GlobalPoolLayer (net, name='pool3') 80 | net = WN(DenseLayer (net, name='dense', num_units=10, **dense_defs), **wn_defs) 81 | 82 | return net 83 | 84 | def rampup(epoch): 85 | if epoch < 80: 86 | p = max(0.0, float(epoch)) / float(80) 87 | p = 1.0 - p 88 | return math.exp(-p*p*5.0) 89 | else: 90 | return 1.0 91 | 92 | def rampdown(epoch): 93 | if epoch >= (300 - 50): 94 | ep = (epoch - (300 - 50)) * 0.5 95 | return math.exp(-(ep * ep) / 50) 96 | else: 97 | return 1.0 98 | 99 | def robust_adam(loss, params, learning_rate, beta1=0.9, beta2=0.999, epsilon=1.0e-8): 100 | # Convert NaNs to zeros. 101 | def clear_nan(x): 102 | return T.switch(T.isnan(x), np.float32(0.0), x) 103 | 104 | new = OrderedDict() 105 | pg = zip(params, lasagne.updates.get_or_compute_grads(loss, params)) 106 | t = theano.shared(lasagne.utils.floatX(0.)) 107 | 108 | new[t] = t + 1.0 109 | coef = learning_rate * T.sqrt(1.0 - beta2**new[t]) / (1.0 - beta1**new[t]) 110 | for p, g in pg: 111 | value = p.get_value(borrow=True) 112 | m = theano.shared(np.zeros(value.shape, dtype=value.dtype), broadcastable=p.broadcastable) 113 | v = theano.shared(np.zeros(value.shape, dtype=value.dtype), broadcastable=p.broadcastable) 114 | new[m] = clear_nan(beta1 * m + (1.0 - beta1) * g) 115 | new[v] = clear_nan(beta2 * v + (1.0 - beta2) * g**2) 116 | new[p] = clear_nan(p - coef * new[m] / (T.sqrt(new[v]) + epsilon)) 117 | 118 | return new 119 | 120 | ''' 121 | parameters 122 | ''' 123 | # global 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("-key", type=str, default=argparse.SUPPRESS) 126 | parser.add_argument("-ssl_seed", type=int, default=1) 127 | parser.add_argument("-nlabeled", type=int, default=4000) 128 | parser.add_argument("-cla_g", type=float, default=0.1) 129 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS) 130 | args = parser.parse_args() 131 | args = vars(args).items() 132 | cfg = {} 133 | for name, val in args: 134 | cfg[name] = val 135 | 136 | filename_script=os.path.basename(os.path.realpath(__file__)) 137 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0]) 138 | outfolder+='.' 139 | for item in cfg: 140 | if item is not 'oldmodel': 141 | outfolder += item+str(cfg[item])+'.' 142 | else: 143 | outfolder += 'oldmodel.' 144 | outfolder+=str(int(time.time())) 145 | if not os.path.exists(outfolder): 146 | os.makedirs(outfolder) 147 | sample_path = os.path.join(outfolder, 'sample') 148 | os.makedirs(sample_path) 149 | logfile=os.path.join(outfolder, 'logfile.log') 150 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script)) 151 | # fixed random seeds 152 | ssl_data_seed=cfg['ssl_seed'] 153 | num_labelled=cfg['nlabeled'] 154 | alpha_cla_g=cfg['cla_g'] 155 | print ssl_data_seed, num_labelled 156 | 157 | seed=1234 158 | rng=np.random.RandomState(seed) 159 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15)) 160 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15))) 161 | 162 | # flags 163 | valid_flag=False 164 | # C 165 | alpha_cla_adv = 0.01 166 | alpha_cla=1. 
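# Added note (editorial commentary): these weights enter the full classifier
# objective defined further below as
#   cla_cost = alpha_cla_adv * 0.5 * cla_cost_p_c      # fool D on (x_u, argmax C(x_u)), confidence-weighted
#            + alpha_cla * (cla_cost_l + cla_cost_u)   # labelled cross-entropy + consistency term, ramped up
#                                                      # towards scaled_unsup_weight_max (defined just below)
#            + sym_alpha_cla_g * cla_cost_cla_g        # cross-entropy on generated pairs, used after epoch_cla_g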
167 | scaled_unsup_weight_max = 100.0 168 | # G 169 | n_z=100 170 | epoch_cla_g=200 171 | # D 172 | noise_D_data=.3 173 | noise_D=.5 174 | # optimization 175 | b1_g=.5 # mom1 in Adam 176 | b1_d=.5 177 | batch_size_g=200 178 | batch_size_l_c=100 179 | batch_size_u_c=100 180 | batch_size_u_d=160 181 | batch_size_l_d=200-batch_size_u_d 182 | lr=3e-4 183 | cla_lr=3e-3 184 | num_epochs=1000 185 | anneal_lr_epoch=300 186 | anneal_lr_every_epoch=1 187 | anneal_lr_factor_cla=.99 188 | anneal_lr_factor=.995 189 | # data dependent 190 | gen_final_non=ln.tanh 191 | num_classes=10 192 | dim_input=(32,32) 193 | in_channels=3 194 | colorImg=True 195 | generation_scale=True 196 | z_generated=num_classes 197 | # evaluation 198 | vis_epoch=10 199 | eval_epoch=1 200 | batch_size_eval=200 201 | 202 | 203 | ''' 204 | data 205 | ''' 206 | def rescale(mat): 207 | return np.cast[theano.config.floatX](mat) 208 | 209 | train_x, train_y = cifar10_data.load('/home/chongxuan/mfs/data/cifar10/','train') 210 | eval_x, eval_y = cifar10_data.load('/home/chongxuan/mfs/data/cifar10/','test') 211 | 212 | train_y = np.int32(train_y) 213 | eval_y = np.int32(eval_y) 214 | train_x = rescale(train_x) 215 | eval_x = rescale(eval_x) 216 | x_unlabelled = train_x.copy() 217 | 218 | rng_data = np.random.RandomState(ssl_data_seed) 219 | inds = rng_data.permutation(train_x.shape[0]) 220 | train_x = train_x[inds] 221 | train_y = train_y[inds] 222 | x_labelled = [] 223 | y_labelled = [] 224 | for j in range(num_classes): 225 | x_labelled.append(train_x[train_y==j][:num_labelled/num_classes]) 226 | y_labelled.append(train_y[train_y==j][:num_labelled/num_classes]) 227 | x_labelled = np.concatenate(x_labelled, axis=0) 228 | y_labelled = np.concatenate(y_labelled, axis=0) 229 | del train_x 230 | 231 | if True: 232 | print 'Size of training data', x_labelled.shape[0], x_unlabelled.shape[0] 233 | y_order = np.argsort(y_labelled) 234 | _x_mean = x_labelled[y_order] 235 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, tile_shape=(num_classes, num_labelled/num_classes), colorImg=colorImg, scale=generation_scale, save_path=os.path.join(outfolder, 'x_l_'+str(ssl_data_seed)+'_triple-gan.png')) 236 | 237 | n_batches_train_u_c = x_unlabelled.shape[0] / batch_size_u_c 238 | n_batches_train_l_c = x_labelled.shape[0] / batch_size_l_c 239 | n_batches_train_u_d = x_unlabelled.shape[0] / batch_size_u_d 240 | n_batches_train_l_d = x_labelled.shape[0] / batch_size_l_d 241 | n_batches_train_g = x_unlabelled.shape[0] / batch_size_g 242 | n_batches_eval = eval_x.shape[0] / batch_size_eval 243 | 244 | 245 | ''' 246 | models 247 | ''' 248 | # symbols 249 | sym_z_image = T.tile(theano_rng.uniform((z_generated, n_z)), (num_classes, 1)) 250 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z)) 251 | sym_x_u = T.tensor4() 252 | sym_x_u_d = T.tensor4() 253 | sym_x_u_g = T.tensor4() 254 | sym_x_l = T.tensor4() 255 | sym_y = T.ivector() 256 | sym_y_g = T.ivector() 257 | sym_x_eval = T.tensor4() 258 | sym_lr = T.scalar() 259 | sym_alpha_cla_g = T.scalar() 260 | sym_alpha_unlabel_entropy = T.scalar() 261 | sym_alpha_unlabel_average = T.scalar() 262 | 263 | # te 264 | sym_lr_cla = T.scalar('separate_lr') 265 | sym_x_u_rep = T.tensor4('two_pass') 266 | sym_unsup_weight = T.scalar('unsup_weight') 267 | sym_b_c = T.scalar('adam_beta1') 268 | 269 | 270 | shared_labeled = theano.shared(x_labelled, borrow=True) 271 | shared_labely = theano.shared(y_labelled, borrow=True) 272 | shared_unlabel = theano.shared(x_unlabelled, borrow=True) 273 | slice_x_u_g = T.ivector() 
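# Added note (editorial commentary): the labelled/unlabelled arrays live in the
# Theano shared variables above so that minibatches can be sliced on the GPU;
# the slice_x_u_* integer index vectors are fed through `givens` in the
# training functions, e.g.
#   givens={sym_x_u: shared_unlabel[slice_x_u_c], sym_x_u_d: shared_unlabel[slice_x_u_d]}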
274 | slice_x_u_d = T.ivector() 275 | slice_x_u_c = T.ivector() 276 | 277 | classifier = build_network() 278 | 279 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z) 280 | gen_in_z = ll.InputLayer(shape=(None, n_z)) 281 | gen_in_y = ll.InputLayer(shape=(None,)) 282 | gen_layers = [gen_in_z] 283 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-00')) 284 | gen_layers.append(nn.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=4*4*512, W=Normal(0.05), nonlinearity=nn.relu, name='gen-01'), g=None, name='gen-02')) 285 | gen_layers.append(ll.ReshapeLayer(gen_layers[-1], (-1,512,4,4), name='gen-03')) 286 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-10')) 287 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,256,8,8), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-11'), g=None, name='gen-12')) # 4 -> 8 288 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-20')) 289 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,128,16,16), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-21'), g=None, name='gen-22')) # 8 -> 16 290 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-30')) 291 | gen_layers.append(nn.weight_norm(nn.Deconv2DLayer(gen_layers[-1], (None,3,32,32), (5,5), W=Normal(0.05), nonlinearity=gen_final_non, name='gen-31'), train_g=True, init_stdv=0.1, name='gen-32')) # 16 -> 32 292 | 293 | # discriminator xy2p: test a pair of input comes from p(x, y) instead of p_c or p_g 294 | dis_in_x = ll.InputLayer(shape=(None, in_channels) + dim_input) 295 | dis_in_y = ll.InputLayer(shape=(None,)) 296 | dis_layers = [dis_in_x] 297 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-00')) 298 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-01')) 299 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-02'), name='dis-03')) 300 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-20')) 301 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-21'), name='dis-22')) 302 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-23')) 303 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-30')) 304 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-31'), name='dis-32')) 305 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-40')) 306 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-41'), name='dis-42')) 307 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-43')) 308 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-50')) 309 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-51'), name='dis-52')) 310 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-60')) 311 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), 
nonlinearity=nn.lrelu, name='dis-61'), name='dis-62')) 312 | dis_layers.append(ll.GlobalPoolLayer(dis_layers[-1], name='dis-63')) 313 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-70')) 314 | dis_layers.append(nn.weight_norm(ll.DenseLayer(dis_layers[-1], num_units=1, W=Normal(0.05), nonlinearity=ln.sigmoid, name='dis-71'), train_g=True, init_stdv=0.1, name='dis-72')) 315 | 316 | 317 | ''' 318 | objectives 319 | ''' 320 | # zca 321 | whitener = ZCA(x=x_unlabelled) 322 | sym_x_l_zca = whitener.apply(sym_x_l) 323 | sym_x_eval_zca = whitener.apply(sym_x_eval) 324 | sym_x_u_zca = whitener.apply(sym_x_u) 325 | sym_x_u_rep_zca = whitener.apply(sym_x_u_rep) 326 | sym_x_u_d_zca = whitener.apply(sym_x_u_d) 327 | 328 | # init 329 | lasagne.layers.get_output(classifier, sym_x_u_zca, init=True) 330 | init_updates = [u for l in lasagne.layers.get_all_layers(classifier) for u in getattr(l, 'init_updates', [])] 331 | init_fn = theano.function([sym_x_u], [], updates=init_updates) 332 | 333 | # outputs 334 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False) 335 | gen_out_x_zca = whitener.apply(gen_out_x) 336 | cla_out_y_l = ll.get_output(classifier, sym_x_l_zca, deterministic=False) 337 | cla_out_y_eval = ll.get_output(classifier, sym_x_eval_zca, deterministic=True) 338 | cla_out_y = ll.get_output(classifier, sym_x_u_zca, deterministic=False) 339 | cla_out_y_rep = ll.get_output(classifier, sym_x_u_rep_zca, deterministic=False) 340 | bn_updates = [u for l in lasagne.layers.get_all_layers(classifier) for u in getattr(l, 'bn_updates', [])] 341 | 342 | cla_out_y_d = ll.get_output(classifier, sym_x_u_d_zca, deterministic=False) 343 | cla_out_y_d_hard = cla_out_y_d.argmax(axis=1) 344 | cla_out_y_g = ll.get_output(classifier, gen_out_x_zca, deterministic=False) 345 | 346 | dis_out_p = ll.get_output(dis_layers[-1], {dis_in_x:T.concatenate([sym_x_l,sym_x_u_d], axis=0),dis_in_y:T.concatenate([sym_y,cla_out_y_d_hard], axis=0)}, deterministic=False) 347 | dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x:gen_out_x,dis_in_y:sym_y_g}, deterministic=False) 348 | # argmax 349 | cla_out_y_hard = cla_out_y.argmax(axis=1) 350 | dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x:sym_x_u,dis_in_y:cla_out_y_hard}, deterministic=False) 351 | 352 | image = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_image}, deterministic=False) # for generation 353 | 354 | accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y)) # for evaluation 355 | accurracy_eval = accurracy_eval.mean() 356 | 357 | # costs 358 | bce = lasagne.objectives.binary_crossentropy 359 | 360 | dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean() # D distincts p 361 | dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean() # D distincts p_g 362 | gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean() # G fools D 363 | 364 | 365 | dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape)) # D distincts p_c 366 | 367 | # argmax 368 | p_cla_max = cla_out_y.max(axis=1) 369 | cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape)) # C fools D 370 | cla_cost_p_c = (cla_cost_p_c*p_cla_max).mean() 371 | dis_cost_p_c = dis_cost_p_c.mean() 372 | 373 | 374 | cla_cost_l = T.mean(lasagne.objectives.categorical_crossentropy(cla_out_y_l, sym_y), dtype=theano.config.floatX, acc_dtype=theano.config.floatX) 375 | 376 | cla_cost_u = sym_unsup_weight * T.mean(lasagne.objectives.squared_error(cla_out_y, 
cla_out_y_rep), dtype=theano.config.floatX, acc_dtype=theano.config.floatX) 377 | 378 | cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g) 379 | 380 | dis_cost = dis_cost_p + .5*dis_cost_p_g + .5*dis_cost_p_c 381 | gen_cost = .5*gen_cost_p_g 382 | cla_cost = alpha_cla_adv * .5 * cla_cost_p_c + alpha_cla * (cla_cost_l + cla_cost_u) + sym_alpha_cla_g * cla_cost_cla_g 383 | # fast 384 | cla_cost_fast = alpha_cla_adv * .5 * cla_cost_p_c + alpha_cla*(cla_cost_l + cla_cost_u) 385 | 386 | dis_cost_list=[dis_cost, dis_cost_p, .5*dis_cost_p_g, .5*dis_cost_p_c] 387 | gen_cost_list=[gen_cost,] 388 | cla_cost_list=[cla_cost, alpha_cla_adv * .5 * cla_cost_p_c, alpha_cla*cla_cost_l, alpha_cla*cla_cost_u, sym_alpha_cla_g*cla_cost_cla_g] 389 | # fast 390 | cla_cost_list_fast=[cla_cost_fast, alpha_cla_adv * .5 * cla_cost_p_c, alpha_cla*cla_cost_l, alpha_cla*cla_cost_u, ] 391 | 392 | # updates of D 393 | dis_params = ll.get_all_params(dis_layers, trainable=True) 394 | dis_grads = T.grad(dis_cost, dis_params) 395 | dis_updates = lasagne.updates.adam(dis_grads, dis_params, beta1=b1_d, learning_rate=sym_lr) 396 | 397 | # updates of G 398 | gen_params = ll.get_all_params(gen_layers, trainable=True) 399 | gen_grads = T.grad(gen_cost, gen_params) 400 | gen_updates = lasagne.updates.adam(gen_grads, gen_params, beta1=b1_g, learning_rate=sym_lr) 401 | 402 | # updates of C 403 | cla_params = ll.get_all_params(classifier, trainable=True) 404 | cla_updates_ = robust_adam(cla_cost, cla_params, learning_rate=sym_lr_cla, beta1=sym_b_c, beta2=.999, epsilon=1e-8) 405 | 406 | # fast updates of C 407 | cla_params = ll.get_all_params(classifier, trainable=True) 408 | cla_updates_fast_ = robust_adam(cla_cost_fast, cla_params, learning_rate=sym_lr_cla, beta1=sym_b_c, beta2=.999, epsilon=1e-8) 409 | 410 | ######## avg 411 | avg_params = lasagne.layers.get_all_params(classifier) 412 | cla_param_avg=[] 413 | for param in avg_params: 414 | value = param.get_value(borrow=True) 415 | cla_param_avg.append(theano.shared(np.zeros(value.shape, dtype=value.dtype), 416 | broadcastable=param.broadcastable, 417 | name=param.name)) 418 | cla_avg_updates = [(a,a + 0.01*(p-a)) for p,a in zip(avg_params,cla_param_avg)] 419 | cla_avg_givens = [(p,a) for p,a in zip(avg_params, cla_param_avg)] 420 | cla_updates = cla_updates_.items() + bn_updates + cla_avg_updates 421 | cla_updates_fast = cla_updates_fast_.items()+ bn_updates + cla_avg_updates 422 | 423 | # functions 424 | train_batch_dis = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, 425 | slice_x_u_c, slice_x_u_d, sym_lr], 426 | outputs=dis_cost_list, updates=dis_updates, 427 | givens={sym_x_u: shared_unlabel[slice_x_u_c], 428 | sym_x_u_d: shared_unlabel[slice_x_u_d]}) 429 | train_batch_gen = theano.function(inputs=[sym_y_g, sym_lr], 430 | outputs=gen_cost_list, updates=gen_updates) 431 | train_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, slice_x_u_c, sym_alpha_cla_g, sym_lr_cla, sym_b_c, sym_unsup_weight], 432 | outputs=cla_cost_list , updates=cla_updates, 433 | givens={sym_x_u: shared_unlabel[slice_x_u_c], 434 | sym_x_u_rep: shared_unlabel[slice_x_u_c]}) 435 | # fast 436 | train_batch_cla_fast = theano.function(inputs=[sym_x_l, sym_y, slice_x_u_c, sym_lr_cla, sym_b_c, sym_unsup_weight], 437 | outputs=cla_cost_list_fast, updates=cla_updates_fast, 438 | givens={sym_x_u: shared_unlabel[slice_x_u_c], 439 | sym_x_u_rep: shared_unlabel[slice_x_u_c]}) 440 | 441 | generate = theano.function(inputs=[sym_y_g], outputs=image) 442 | # avg 443 | evaluate = 
theano.function(inputs=[sym_x_eval, sym_y], outputs=[accurracy_eval], givens=cla_avg_givens) 444 | 445 | ''' 446 | Load pretrained model 447 | ''' 448 | if 'oldmodel' in cfg: 449 | from utils.checkpoints import load_weights 450 | load_weights(cfg['oldmodel'], dis_layers+[classifier,]+gen_layers) 451 | for (p, a) in zip(ll.get_all_params(classifier), avg_params): 452 | a.set_value(p.get_value()) 453 | 454 | ''' 455 | train and evaluate 456 | ''' 457 | 458 | init_fn(x_unlabelled[:batch_size_u_c]) 459 | 460 | print 'Start training' 461 | for epoch in range(1, 1+num_epochs): 462 | start = time.time() 463 | 464 | # randomly permute data and labels 465 | p_l = rng.permutation(x_labelled.shape[0]) 466 | x_labelled = x_labelled[p_l] 467 | y_labelled = y_labelled[p_l] 468 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32') 469 | p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32') 470 | p_u_g = rng.permutation(x_unlabelled.shape[0]).astype('int32') 471 | 472 | dl = [0.] * len(dis_cost_list) 473 | gl = [0.] * len(gen_cost_list) 474 | cl = [0.] * len(cla_cost_list) 475 | 476 | # fast 477 | 478 | dis_t=0. 479 | gen_t=0. 480 | cla_t=0. 481 | 482 | # te 483 | rampup_value = rampup(epoch-1) 484 | rampdown_value = rampdown(epoch-1) 485 | lr_c = cla_lr 486 | b1_c = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5 487 | unsup_weight = rampup_value * scaled_unsup_weight_max if epoch > 1 else 0. 488 | 489 | #print "@cla", lr_c, b1_c, unsup_weight 490 | 491 | for i in range(n_batches_train_u_c): 492 | from_u_c = i*batch_size_u_c 493 | to_u_c = (i+1)*batch_size_u_c 494 | i_c = i % n_batches_train_l_c 495 | from_l_c = i_c*batch_size_l_c 496 | to_l_c = (i_c+1)*batch_size_l_c 497 | i_d = i % n_batches_train_l_d 498 | from_l_d = i_d*batch_size_l_d 499 | to_l_d = (i_d+1)*batch_size_l_d 500 | i_d_ = i % n_batches_train_u_d 501 | from_u_d = i_d_*batch_size_u_d 502 | to_u_d = (i_d_+1)*batch_size_u_d 503 | 504 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 505 | 506 | 507 | tmp = time.time() 508 | dl_b = train_batch_dis(x_labelled[from_l_d:to_l_d], y_labelled[from_l_d:to_l_d], sample_y, p_u[from_u_c:to_u_c], p_u_d[from_u_d:to_u_d], lr) 509 | for j in xrange(len(dl)): 510 | dl[j] += dl_b[j] 511 | 512 | tmp1 = time.time() 513 | 514 | #gl_b = train_batch_gen(sample_y, p_u_g[from_u_g:to_u_g], lr) 515 | gl_b = train_batch_gen(sample_y, lr) 516 | for j in xrange(len(gl)): 517 | gl[j] += gl_b[j] 518 | 519 | tmp2 = time.time() 520 | 521 | # fast 522 | if epoch < epoch_cla_g: 523 | cl_b = train_batch_cla_fast(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], p_u[from_u_c:to_u_c], lr_c, b1_c, unsup_weight) 524 | cl_b += [0,] 525 | else: 526 | cl_b = train_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], sample_y, p_u[from_u_c:to_u_c], alpha_cla_g, lr_c, b1_c, unsup_weight) 527 | for j in xrange(len(cl)): 528 | cl[j] += cl_b[j] 529 | 530 | tmp3 = time.time() 531 | dis_t+=(tmp1-tmp) 532 | gen_t+=(tmp2-tmp1) 533 | cla_t+=(tmp3-tmp2) 534 | 535 | print 'dis:', dis_t, 'gen:', gen_t, 'cla:', cla_t, 'total', dis_t+gen_t+cla_t 536 | 537 | for i in xrange(len(dl)): 538 | dl[i] /= n_batches_train_u_c 539 | for i in xrange(len(gl)): 540 | gl[i] /= n_batches_train_u_c 541 | for i in xrange(len(cl)): 542 | cl[i] /= n_batches_train_u_c 543 | 544 | if (epoch >= anneal_lr_epoch) and (epoch % anneal_lr_every_epoch == 0): 545 | lr = lr*anneal_lr_factor 546 | cla_lr *= anneal_lr_factor_cla 547 | 548 | t = time.time() - start 549 | 550 | line = "*Epoch=%d Time=%.2f 
LR=%.5f\n" %(epoch, t, lr) + "DisLosses: " + str(dl)+"\nGenLosses: "+str(gl)+"\nClaLosses: "+str(cl) 551 | 552 | print line 553 | with open(logfile,'a') as f: 554 | f.write(line + "\n") 555 | 556 | # random generation for visualization 557 | if epoch % vis_epoch == 0: 558 | import utils.paramgraphics as paramgraphics 559 | tail = '-'+str(epoch)+'.png' 560 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes)) 561 | x_gen = generate(ran_y) 562 | x_gen = x_gen.reshape((z_generated*num_classes,-1)) 563 | image = paramgraphics.mat_to_img(x_gen.T, dim_input, colorImg=colorImg, scale=generation_scale, save_path=os.path.join(sample_path, 'sample'+tail)) 564 | 565 | if epoch % eval_epoch == 0: 566 | accurracy=[] 567 | for i in range(n_batches_eval): 568 | accurracy_batch = evaluate(eval_x[i*batch_size_eval:(i+1)*batch_size_eval], eval_y[i*batch_size_eval:(i+1)*batch_size_eval]) 569 | accurracy += accurracy_batch 570 | 571 | accurracy=np.mean(accurracy) 572 | print ('ErrorEval=%.5f\n' % (1-accurracy,)) 573 | with open(logfile,'a') as f: 574 | f.write(('ErrorEval=%.5f\n\n' % (1-accurracy,))) 575 | 576 | if epoch % 200 == 0 or (epoch == epoch_cla_g - 1): 577 | from utils.checkpoints import save_weights 578 | params = ll.get_all_params(dis_layers+[classifier,]+gen_layers) 579 | save_weights(os.path.join(outfolder, 'model_epoch' + str(epoch) + '.npy'), params, None) 580 | save_weights(os.path.join(outfolder, 'average'+ str(epoch) +'.npy'), cla_param_avg, None) 581 | -------------------------------------------------------------------------------- /x2y_yz2x_xy2p_ssl_mnist.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code implements a triple GAN for semi-supervised learning on MNIST 3 | ''' 4 | 5 | import os, time, argparse, shutil 6 | 7 | import numpy as np 8 | import theano, lasagne 9 | import theano.tensor as T 10 | import lasagne.layers as ll 11 | import lasagne.nonlinearities as ln 12 | import nn 13 | from lasagne.init import Normal 14 | from theano.sandbox.rng_mrg import MRG_RandomStreams 15 | from parmesan.datasets import load_mnist_realval 16 | 17 | from layers.merge import MLPConcatLayer 18 | 19 | from components.shortcuts import convlayer 20 | from components.objectives import categorical_crossentropy_ssl_separated, categorical_crossentropy 21 | import utils.paramgraphics as paramgraphics 22 | 23 | 24 | ''' 25 | parameters 26 | ''' 27 | # global 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("-key", type=str, default=argparse.SUPPRESS) 30 | parser.add_argument("-ssl_seed", type=int, default=1) 31 | parser.add_argument("-nlabeled", type=int, default=100) 32 | parser.add_argument("-objective_flag", type=str, default='argmax') 33 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS) 34 | args = parser.parse_args() 35 | args = vars(args).items() 36 | cfg = {} 37 | for name, val in args: 38 | cfg[name] = val 39 | if cfg['ssl_seed'] == -1: 40 | cfg['ssl_seed'] = int(time.time()) 41 | 42 | filename_script=os.path.basename(os.path.realpath(__file__)) 43 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0]) 44 | outfolder+='.' 45 | for item in cfg: 46 | if item is not 'oldmodel': 47 | outfolder += item+str(cfg[item])+'.' 48 | else: 49 | outfolder += 'oldmodel.' 
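        # Added note (editorial commentary): the run directory concatenates every
        # config entry plus a Unix timestamp, e.g.
        #   results-ssl/x2y_yz2x_xy2p_ssl_mnist.ssl_seed1.nlabeled100.objective_flagargmax.<timestamp>
        # (the item order follows Python 2 dict iteration and is not guaranteed).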
50 | outfolder+=str(int(time.time())) 51 | if not os.path.exists(outfolder): 52 | os.makedirs(outfolder) 53 | sample_path = os.path.join(outfolder, 'sample') 54 | os.makedirs(sample_path) 55 | logfile=os.path.join(outfolder, 'logfile.log') 56 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script)) 57 | # fixed random seeds 58 | ssl_data_seed=cfg['ssl_seed'] 59 | num_labelled=cfg['nlabeled'] 60 | print ssl_data_seed, num_labelled 61 | 62 | seed=1234 63 | rng=np.random.RandomState(seed) 64 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15)) 65 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15))) 66 | 67 | # dataset 68 | data_dir='/home/chongxuan/mfs/data/mnist_real/mnist.pkl.gz' 69 | # flags 70 | valid_flag=False 71 | objective_flag = cfg['objective_flag'] # integrate y, argmax 72 | 73 | # pre-train C 74 | pre_num_epoch = 0 if num_labelled > 100 else 30 75 | pre_alpha_unlabeled_entropy=.3 76 | pre_alpha_average=.3 77 | pre_lr=3e-4 78 | pre_batch_size_lc=min(100, num_labelled) 79 | pre_batch_size_uc=500 80 | # C 81 | alpha_decay=1e-4 82 | alpha_labeled=1. 83 | alpha_unlabeled_entropy=.3 84 | alpha_average=1e-3 85 | alpha_cla=1. 86 | # G 87 | n_z=100 88 | alpha_cla_g=.1 89 | epoch_cla_g=300 90 | # D 91 | noise_D_data=.3 92 | noise_D=.5 93 | # optimization 94 | b1_g=.5 # mom1 in Adam 95 | b1_d=.5 96 | b1_c=.5 97 | 98 | # adjust batch size for different number of labeled data 99 | batch_size_g=200 100 | batch_size_l_c=min(100, num_labelled) 101 | batch_size_u_c=max(100, 10000/num_labelled) 102 | batch_size_u_d=400 103 | batch_size_l_d=max(num_labelled/100, 1) 104 | lr=1e-3 105 | num_epochs=1000 106 | anneal_lr_epoch=300 107 | anneal_lr_every_epoch=1 108 | anneal_lr_factor=.995 109 | # data dependent 110 | gen_final_non=ln.sigmoid 111 | num_classes=10 112 | dim_input=(28,28) 113 | in_channels=1 114 | colorImg=False 115 | generation_scale=False 116 | z_generated=num_classes 117 | # evaluation 118 | vis_epoch=10 119 | eval_epoch=1 120 | 121 | 122 | ''' 123 | data 124 | ''' 125 | train_x, train_y, valid_x, valid_y, eval_x, eval_y = load_mnist_realval(data_dir) 126 | if valid_flag: 127 | eval_x = valid_x 128 | eval_y = valid_y 129 | else: 130 | train_x = np.concatenate([train_x, valid_x]) 131 | train_y = np.hstack((train_y, valid_y)) 132 | train_y = np.int32(train_y) 133 | eval_y = np.int32(eval_y) 134 | train_x = train_x.astype('float32') 135 | eval_x = eval_x.astype('float32') 136 | x_unlabelled = train_x.copy() 137 | 138 | rng_data = np.random.RandomState(ssl_data_seed) 139 | inds = rng_data.permutation(train_x.shape[0]) 140 | train_x = train_x[inds] 141 | train_y = train_y[inds] 142 | x_labelled = [] 143 | y_labelled = [] 144 | for j in range(num_classes): 145 | x_labelled.append(train_x[train_y==j][:num_labelled/num_classes]) 146 | y_labelled.append(train_y[train_y==j][:num_labelled/num_classes]) 147 | x_labelled = np.concatenate(x_labelled, axis=0) 148 | y_labelled = np.concatenate(y_labelled, axis=0) 149 | del train_x 150 | 151 | if True: 152 | print 'Size of training data', x_labelled.shape[0], x_unlabelled.shape[0] 153 | y_order = np.argsort(y_labelled) 154 | _x_mean = x_labelled[y_order] 155 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, tile_shape=(num_classes, num_labelled/num_classes),colorImg=colorImg, scale=generation_scale, save_path=os.path.join(outfolder, 'x_l_'+str(ssl_data_seed)+'_triple-gan.png')) 156 | 157 | pretrain_batches_train_uc = x_unlabelled.shape[0] / pre_batch_size_uc 158 | pretrain_batches_train_lc = 
x_labelled.shape[0] / pre_batch_size_lc 159 | n_batches_train_u_c = x_unlabelled.shape[0] / batch_size_u_c 160 | n_batches_train_l_c = x_labelled.shape[0] / batch_size_l_c 161 | n_batches_train_u_d = x_unlabelled.shape[0] / batch_size_u_d 162 | n_batches_train_l_d = x_labelled.shape[0] / batch_size_l_d 163 | n_batches_train_g = x_unlabelled.shape[0] / batch_size_g 164 | 165 | 166 | ''' 167 | models 168 | ''' 169 | # symbols 170 | sym_z_image = T.tile(theano_rng.uniform((z_generated, n_z)), (num_classes, 1)) 171 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z)) 172 | sym_x_u = T.matrix() 173 | sym_x_u_d = T.matrix() 174 | sym_x_u_g = T.matrix() 175 | sym_x_l = T.matrix() 176 | sym_y = T.ivector() 177 | sym_y_g = T.ivector() 178 | sym_x_eval = T.matrix() 179 | sym_lr = T.scalar() 180 | sym_alpha_cla_g = T.scalar() 181 | sym_alpha_unlabel_entropy = T.scalar() 182 | sym_alpha_unlabel_average = T.scalar() 183 | 184 | shared_unlabel = theano.shared(x_unlabelled, borrow=True) 185 | slice_x_u_g = T.ivector() 186 | slice_x_u_d = T.ivector() 187 | slice_x_u_c = T.ivector() 188 | 189 | # classifier x2y: p_c(x, y) = p(x) p_c(y | x) 190 | cla_in_x = ll.InputLayer(shape=(None, 28**2)) 191 | cla_layers = [cla_in_x] 192 | cla_layers.append(ll.ReshapeLayer(cla_layers[-1], (-1,1,28,28))) 193 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0.5, ps=2, n_kerns=32, d_kerns=(5,5), pad='valid', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-1')) 194 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=64, d_kerns=(3,3), pad='same', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-2')) 195 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0.5, ps=2, n_kerns=64, d_kerns=(3,3), pad='valid', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-3')) 196 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=128, d_kerns=(3,3), pad='same', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-4')) 197 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=128, d_kerns=(3,3), pad='same', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-5')) 198 | cla_layers.append(ll.GlobalPoolLayer(cla_layers[-1])) 199 | cla_layers.append(ll.DenseLayer(cla_layers[-1], num_units=num_classes, W=lasagne.init.Normal(1e-2, 0), nonlinearity=ln.softmax, name='cla-6')) 200 | classifier = cla_layers[-1] 201 | 202 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z) 203 | gen_in_z = ll.InputLayer(shape=(None, n_z)) 204 | gen_in_y = ll.InputLayer(shape=(None,)) 205 | gen_layers = [gen_in_z] 206 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-1')) 207 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-2'), name='gen-3')) 208 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-4')) 209 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-5'), name='gen-6')) 210 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-7')) 211 | gen_layers.append(nn.l2normalize(ll.DenseLayer(gen_layers[-1], num_units=28**2, nonlinearity=gen_final_non, name='gen-8'))) 212 | 213 | # discriminator xy2p: test a pair of input comes from p(x, y) instead of p_c or p_g 214 | dis_in_x = ll.InputLayer(shape=(None, 28**2)) 215 | dis_in_y = ll.InputLayer(shape=(None,)) 216 | 
dis_layers = [dis_in_x] 217 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D_data, name='dis-1')) 218 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-2')) 219 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=1000, name='dis-3')) 220 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-4')) 221 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-5')) 222 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-6')) 223 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=500, name='dis-7')) 224 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-8')) 225 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-9')) 226 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-10')) 227 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-11')) 228 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-12')) 229 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-13')) 230 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-14')) 231 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-15')) 232 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-16')) 233 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-17')) 234 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-18')) 235 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=1, nonlinearity=ln.sigmoid, name='dis-19')) 236 | 237 | 238 | ''' 239 | objectives 240 | ''' 241 | # outputs 242 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False) 243 | 244 | cla_out_y_l = ll.get_output(cla_layers[-1], sym_x_l, deterministic=False) 245 | cla_out_y_eval = ll.get_output(cla_layers[-1], sym_x_eval, deterministic=True) 246 | cla_out_y = ll.get_output(cla_layers[-1], sym_x_u, deterministic=False) 247 | cla_out_y_d = ll.get_output(cla_layers[-1], {cla_in_x:sym_x_u_d}, deterministic=False) 248 | cla_out_y_d_hard = cla_out_y_d.argmax(axis=1) 249 | cla_out_y_g = ll.get_output(cla_layers[-1], {cla_in_x:gen_out_x}, deterministic=False) 250 | 251 | dis_out_p = ll.get_output(dis_layers[-1], {dis_in_x:T.concatenate([sym_x_l,sym_x_u_d], axis=0),dis_in_y:T.concatenate([sym_y,cla_out_y_d_hard], axis=0)}, deterministic=False) 252 | dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x:gen_out_x,dis_in_y:sym_y_g}, deterministic=False) 253 | 254 | if objective_flag == 'integrate': 255 | # integrate 256 | dis_out_p_c = ll.get_output(dis_layers[-1], 257 | {dis_in_x:T.repeat(sym_x_u, num_classes, axis=0), 258 | dis_in_y:np.tile(np.arange(num_classes), batch_size_u_c)}, 259 | deterministic=False) 260 | elif objective_flag == 'argmax': 261 | # argmax approximation 262 | cla_out_y_hard = cla_out_y.argmax(axis=1) 263 | dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x:sym_x_u,dis_in_y:cla_out_y_hard}, deterministic=False) 264 | else: 265 | raise Exception('Unknown objective flags') 266 | 267 | image = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_image}, deterministic=False) # for generation 268 | 269 | accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y)) # for evaluation 270 | 
accurracy_eval = accurracy_eval.mean() 271 | 272 | # costs 273 | bce = lasagne.objectives.binary_crossentropy 274 | 275 | dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean() # D distincts p 276 | dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean() # D distincts p_g 277 | gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean() # G fools D 278 | 279 | weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted({cla_layers[-1]:1}, lasagne.regularization.l2) # weight decay 280 | 281 | dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape)) # D distincts p_c 282 | cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape)) # C fools D 283 | 284 | if objective_flag == 'integrate': 285 | # integrate 286 | weight_loss_c = T.reshape(cla_cost_p_c, (-1, num_classes)) * cla_out_y 287 | cla_cost_p_c = T.sum(weight_loss_c, axis=1).mean() 288 | weight_loss_d = T.reshape(dis_cost_p_c, (-1, num_classes)) * cla_out_y 289 | dis_cost_p_c = T.sum(weight_loss_d, axis=1).mean() 290 | elif objective_flag == 'argmax': 291 | # argmax approximation 292 | p = cla_out_y.max(axis=1) 293 | cla_cost_p_c = (cla_cost_p_c*p).mean() 294 | dis_cost_p_c = dis_cost_p_c.mean() 295 | 296 | cla_cost_cla = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=sym_alpha_unlabel_entropy, alpha_average=sym_alpha_unlabel_average, alpha_decay=alpha_decay) # classification loss 297 | 298 | pretrain_cla_loss = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=pre_alpha_unlabeled_entropy, alpha_average=pre_alpha_average, alpha_decay=alpha_decay) # classification loss 299 | pretrain_cost = pretrain_cla_loss 300 | 301 | cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g) 302 | 303 | dis_cost = dis_cost_p + .5*dis_cost_p_g + .5*dis_cost_p_c 304 | gen_cost = .5*gen_cost_p_g 305 | # flag 306 | cla_cost = .5*cla_cost_p_c + alpha_cla*(cla_cost_cla + sym_alpha_cla_g*cla_cost_cla_g) 307 | # fast 308 | cla_cost_fast = .5*cla_cost_p_c + alpha_cla*cla_cost_cla 309 | 310 | dis_cost_list=[dis_cost, dis_cost_p, .5*dis_cost_p_g, .5*dis_cost_p_c] 311 | gen_cost_list=[gen_cost] 312 | # flag 313 | cla_cost_list=[cla_cost, .5*cla_cost_p_c, alpha_cla*cla_cost_cla, alpha_cla*sym_alpha_cla_g*cla_cost_cla_g] 314 | # fast 315 | cla_cost_list_fast=[cla_cost_fast, .5*cla_cost_p_c, alpha_cla*cla_cost_cla] 316 | 317 | # updates of D 318 | dis_params = ll.get_all_params(dis_layers, trainable=True) 319 | dis_grads = T.grad(dis_cost, dis_params) 320 | dis_updates = lasagne.updates.adam(dis_grads, dis_params, beta1=b1_d, learning_rate=sym_lr) 321 | 322 | # updates of G 323 | gen_params = ll.get_all_params(gen_layers, trainable=True) 324 | gen_grads = T.grad(gen_cost, gen_params) 325 | gen_updates = lasagne.updates.adam(gen_grads, gen_params, beta1=b1_g, learning_rate=sym_lr) 326 | 327 | # updates of C 328 | cla_params = ll.get_all_params(cla_layers, trainable=True) 329 | cla_grads = T.grad(cla_cost, cla_params) 330 | cla_updates_ = lasagne.updates.adam(cla_grads, cla_params, beta1=b1_c, learning_rate=sym_lr) 331 | 332 | # fast updates of C 333 | cla_params = ll.get_all_params(cla_layers, trainable=True) 334 | cla_grads_fast = T.grad(cla_cost_fast, cla_params) 335 | cla_updates_fast_ = 
lasagne.updates.adam(cla_grads_fast, cla_params, beta1=b1_c, learning_rate=sym_lr) 336 | 337 | pre_cla_grad = T.grad(pretrain_cost, cla_params) 338 | pretrain_updates_ = lasagne.updates.adam(pre_cla_grad, cla_params, beta1=0.9, beta2=0.999, 339 | epsilon=1e-8, learning_rate=pre_lr) 340 | 341 | ######## avg 342 | avg_params = lasagne.layers.get_all_params(cla_layers) 343 | cla_param_avg=[] 344 | for param in avg_params: 345 | value = param.get_value(borrow=True) 346 | cla_param_avg.append(theano.shared(np.zeros(value.shape, dtype=value.dtype), 347 | broadcastable=param.broadcastable, 348 | name=param.name)) 349 | cla_avg_updates = [(a,a + 0.01*(p-a)) for p,a in zip(avg_params,cla_param_avg)] 350 | cla_avg_givens = [(p,a) for p,a in zip(avg_params, cla_param_avg)] 351 | cla_updates = cla_updates_.items() + cla_avg_updates 352 | cla_updates_fast = cla_updates_fast_.items() + cla_avg_updates 353 | pretrain_updates = pretrain_updates_.items() + cla_avg_updates 354 | 355 | # functions 356 | train_batch_dis = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, 357 | slice_x_u_c, slice_x_u_d, sym_lr], 358 | outputs=dis_cost_list, updates=dis_updates, 359 | givens={sym_x_u: shared_unlabel[slice_x_u_c], 360 | sym_x_u_d: shared_unlabel[slice_x_u_d]}) 361 | train_batch_gen = theano.function(inputs=[sym_y_g, sym_lr], 362 | outputs=gen_cost_list, updates=gen_updates) 363 | train_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, slice_x_u_c, sym_alpha_cla_g, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average], 364 | outputs=cla_cost_list, updates=cla_updates, 365 | givens={sym_x_u: shared_unlabel[slice_x_u_c]}) 366 | # fast 367 | train_batch_cla_fast = theano.function(inputs=[sym_x_l, sym_y, slice_x_u_c, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average], 368 | outputs=cla_cost_list_fast, updates=cla_updates_fast, 369 | givens={sym_x_u: shared_unlabel[slice_x_u_c]}) 370 | 371 | sym_index = T.iscalar() 372 | bslice = slice(sym_index*pre_batch_size_uc, (sym_index+1)*pre_batch_size_uc) 373 | pretrain_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_index], 374 | outputs=[pretrain_cost], updates=pretrain_updates, 375 | givens={sym_x_u: shared_unlabel[bslice]}) 376 | generate = theano.function(inputs=[sym_y_g], outputs=image) 377 | 378 | # avg 379 | evaluate = theano.function(inputs=[sym_x_eval, sym_y], outputs=[accurracy_eval], givens=cla_avg_givens) 380 | 381 | 382 | ''' 383 | Load pretrained model 384 | ''' 385 | if 'oldmodel' in cfg: 386 | from utils.checkpoints import load_weights 387 | load_weights(cfg['oldmodel'], cla_layers) 388 | for (p, a) in zip(ll.get_all_params(cla_layers), avg_params): 389 | a.set_value(p.get_value()) 390 | 391 | 392 | ''' 393 | Pretrain C 394 | ''' 395 | for epoch in range(1, 1+pre_num_epoch): 396 | # randomly permute data and labels 397 | p_l = rng.permutation(x_labelled.shape[0]) 398 | x_labelled = x_labelled[p_l] 399 | y_labelled = y_labelled[p_l] 400 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32') 401 | shared_unlabel.set_value(x_unlabelled[p_u]) 402 | 403 | for i in range(pretrain_batches_train_uc): 404 | i_c = i % pretrain_batches_train_lc 405 | from_l_c = i_c*pre_batch_size_lc 406 | to_l_c = (i_c+1)*pre_batch_size_lc 407 | 408 | pretrain_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], i) 409 | acc = evaluate(eval_x, eval_y) 410 | acc = acc[0] 411 | print str(epoch) + ':Pretrain accuracy: ' + str(1-acc) 412 | 413 | 414 | ''' 415 | train and evaluate 416 | ''' 417 | for epoch in range(1, 1+num_epochs): 
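    # Added outline of one epoch (editorial commentary, not original code):
    #   1. reshuffle the labelled set and draw fresh permutations of the
    #      unlabelled indices for C and D;
    #   2. for every unlabelled batch take one Adam step for D, one for G and
    #      one for C (train_batch_cla_fast before epoch_cla_g, i.e. without the
    #      cross-entropy term on generated pairs; train_batch_cla afterwards);
    #   3. anneal the learning rate (from anneal_lr_epoch on), log the averaged
    #      losses, and periodically visualise samples, evaluate with the
    #      averaged classifier weights and save checkpoints.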
418 | start = time.time() 419 | 420 | # randomly permute data and labels 421 | p_l = rng.permutation(x_labelled.shape[0]) 422 | x_labelled = x_labelled[p_l] 423 | y_labelled = y_labelled[p_l] 424 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32') 425 | p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32') 426 | p_u_g = rng.permutation(x_unlabelled.shape[0]).astype('int32') 427 | 428 | dl = [0.] * len(dis_cost_list) 429 | gl = [0.] * len(gen_cost_list) 430 | cl = [0.] * len(cla_cost_list) 431 | 432 | # fast 433 | 434 | for i in range(n_batches_train_u_c): 435 | from_u_c = i*batch_size_u_c 436 | to_u_c = (i+1)*batch_size_u_c 437 | i_c = i % n_batches_train_l_c 438 | from_l_c = i_c*batch_size_l_c 439 | to_l_c = (i_c+1)*batch_size_l_c 440 | i_d = i % n_batches_train_l_d 441 | from_l_d = i_d*batch_size_l_d 442 | to_l_d = (i_d+1)*batch_size_l_d 443 | i_d_ = i % n_batches_train_u_d 444 | from_u_d = i_d_*batch_size_u_d 445 | to_u_d = (i_d_+1)*batch_size_u_d 446 | 447 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 448 | 449 | dl_b = train_batch_dis(x_labelled[from_l_d:to_l_d], y_labelled[from_l_d:to_l_d], sample_y, p_u[from_u_c:to_u_c], p_u_d[from_u_d:to_u_d], lr) 450 | for j in xrange(len(dl)): 451 | dl[j] += dl_b[j] 452 | 453 | gl_b = train_batch_gen(sample_y, lr) 454 | for j in xrange(len(gl)): 455 | gl[j] += gl_b[j] 456 | 457 | # fast 458 | if epoch < epoch_cla_g: 459 | cl_b = train_batch_cla_fast(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], p_u[from_u_c:to_u_c], lr, alpha_unlabeled_entropy, alpha_average) 460 | cl_b += [0,] 461 | else: 462 | cl_b = train_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], sample_y, p_u[from_u_c:to_u_c], alpha_cla_g, lr, alpha_unlabeled_entropy, alpha_average) 463 | 464 | for j in xrange(len(cl)): 465 | cl[j] += cl_b[j] 466 | 467 | for i in xrange(len(dl)): 468 | dl[i] /= n_batches_train_u_c 469 | for i in xrange(len(gl)): 470 | gl[i] /= n_batches_train_u_c 471 | for i in xrange(len(cl)): 472 | cl[i] /= n_batches_train_u_c 473 | 474 | if (epoch >= anneal_lr_epoch) and (epoch % anneal_lr_every_epoch == 0): 475 | lr = lr*anneal_lr_factor 476 | 477 | t = time.time() - start 478 | 479 | line = "*Epoch=%d Time=%.2f LR=%.5f\n" %(epoch, t, lr) + "DisLosses: " + str(dl)+"\nGenLosses: "+str(gl)+"\nClaLosses: "+str(cl) 480 | 481 | print line 482 | with open(logfile,'a') as f: 483 | f.write(line + "\n") 484 | 485 | # random generation for visualization 486 | if epoch % vis_epoch == 0: 487 | import utils.paramgraphics as paramgraphics 488 | tail = '-'+str(epoch)+'.png' 489 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes)) 490 | x_gen = generate(ran_y) 491 | x_gen = x_gen.reshape((z_generated*num_classes,-1)) 492 | image = paramgraphics.mat_to_img(x_gen.T, dim_input, colorImg=colorImg, scale=generation_scale, save_path=os.path.join(sample_path, 'sample'+tail)) 493 | 494 | if epoch % eval_epoch == 0: 495 | acc = evaluate(eval_x, eval_y) 496 | acc = acc[0] 497 | print ('ErrorEval=%.5f\n' % (1-acc,)) 498 | with open(logfile,'a') as f: 499 | f.write(('ErrorEval=%.5f\n\n' % (1-acc,))) 500 | 501 | if epoch % 200 == 0 or epoch == epoch_cla_g-1: 502 | from utils.checkpoints import save_weights 503 | params = ll.get_all_params(dis_layers+cla_layers+gen_layers) 504 | save_weights(os.path.join(outfolder, 'model_epoch' + str(epoch) + '.npy'), params, None) 505 | 506 | save_weights(os.path.join(outfolder, 'average.npy'), cla_param_avg, None) 507 | 
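# ---------------------------------------------------------------------------
# Added usage sketch (editorial commentary; paths are illustrative, not from
# the repo). Typical invocations of this script, matching the argparse
# defaults above:
#   python x2y_yz2x_xy2p_ssl_mnist.py -ssl_seed 1 -nlabeled 100 -objective_flag argmax
#   python x2y_yz2x_xy2p_ssl_mnist.py -nlabeled 100 -oldmodel results-ssl/<run>/<weights>.npy
# Passing -oldmodel restores pretrained classifier weights via
# utils.checkpoints.load_weights before training continues.
# ---------------------------------------------------------------------------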
-------------------------------------------------------------------------------- /x2y_yz2x_xy2p_ssl_svhn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code implements a triple GAN for semi-supervised learning on SVHN 3 | ''' 4 | import os, time, argparse, shutil 5 | 6 | import numpy as np 7 | import theano, lasagne 8 | import theano.tensor as T 9 | import lasagne.layers as ll 10 | import lasagne.nonlinearities as ln 11 | from lasagne.layers import dnn 12 | import nn 13 | from lasagne.init import Normal 14 | from theano.sandbox.rng_mrg import MRG_RandomStreams 15 | import svhn_data 16 | 17 | from layers.merge import ConvConcatLayer, MLPConcatLayer 18 | 19 | from components.objectives import categorical_crossentropy_ssl_separated, categorical_crossentropy 20 | import utils.paramgraphics as paramgraphics 21 | 22 | 23 | ''' 24 | parameters 25 | ''' 26 | # global 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("-key", type=str, default=argparse.SUPPRESS) 29 | parser.add_argument("-ssl_seed", type=int, default=1) 30 | parser.add_argument("-nlabeled", type=int, default=1000) 31 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS) 32 | args = parser.parse_args() 33 | args = vars(args).items() 34 | cfg = {} 35 | for name, val in args: 36 | cfg[name] = val 37 | 38 | filename_script=os.path.basename(os.path.realpath(__file__)) 39 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0]) 40 | outfolder+='.' 41 | for item in cfg: 42 | if item is not 'oldmodel': 43 | outfolder += item+str(cfg[item])+'.' 44 | else: 45 | outfolder += 'oldmodel.' 46 | outfolder+=str(int(time.time())) 47 | if not os.path.exists(outfolder): 48 | os.makedirs(outfolder) 49 | sample_path = os.path.join(outfolder, 'sample') 50 | os.makedirs(sample_path) 51 | logfile=os.path.join(outfolder, 'logfile.log') 52 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script)) 53 | # fixed random seeds 54 | ssl_data_seed=cfg['ssl_seed'] 55 | num_labelled=cfg['nlabeled'] 56 | print ssl_data_seed, num_labelled 57 | 58 | seed=1234 59 | rng=np.random.RandomState(seed) 60 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15)) 61 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15))) 62 | 63 | # dataset 64 | data_dir='/home/chongxuan/mfs/data/svhn' 65 | 66 | # flags 67 | valid_flag=False 68 | objective_flag='argmax' # integrate y, choose argmax 69 | 70 | # pre-train C 71 | pre_num_epoch=0 72 | pre_alpha_unlabeled_entropy=.3 73 | pre_alpha_average=1e-3 74 | pre_lr=3e-4 75 | pre_batch_size_lc=200 76 | pre_batch_size_uc=200 77 | # C 78 | alpha_decay=1e-4 79 | alpha_labeled=1. 80 | alpha_unlabeled_entropy=.3 81 | alpha_average=1e-3 82 | alpha_cla=1. 
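# Added note (editorial commentary): these weights are presumably passed to
# categorical_crossentropy_ssl_separated (components/objectives.py) as in the
# MNIST script above: alpha_labeled scales the supervised cross-entropy,
# alpha_unlabeled_entropy a (presumed) entropy-style term on the unlabelled
# predictions, alpha_average a term on the averaged/marginal prediction, and
# alpha_decay the L2 weight decay; alpha_cla weighs the whole classification
# loss inside the Triple-GAN classifier objective.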
83 | alpha_cla_adversarial=.01 84 | # G 85 | n_z=100 86 | alpha_cla_g=0.03 87 | epoch_cla_g=200 88 | # D 89 | noise_D_data=.3 90 | noise_D=.5 91 | # optimization 92 | b1_g=.5 # mom1 in Adam 93 | b1_d=.5 94 | b1_c=.5 95 | batch_size_g=200 96 | batch_size_l_c=200 97 | batch_size_u_c=200 98 | batch_size_u_d=180 99 | batch_size_l_d=20 100 | lr=3e-4 101 | num_epochs=1000 102 | anneal_lr_epoch=300 103 | anneal_lr_every_epoch=1 104 | anneal_lr_factor=.995 105 | # data dependent 106 | gen_final_non=ln.tanh 107 | num_classes=10 108 | dim_input=(32,32) 109 | in_channels=3 110 | colorImg=True 111 | generation_scale=True 112 | z_generated=num_classes 113 | # evaluation 114 | vis_epoch=10 115 | eval_epoch=1 116 | batch_size_eval=200 117 | 118 | 119 | ''' 120 | data 121 | ''' 122 | if valid_flag: 123 | def rescale(mat): 124 | return (((-127.5 + mat)/127.5).astype(theano.config.floatX)).reshape((-1, in_channels)+dim_input) 125 | train_x, train_y, valid_x, valid_y, eval_x, eval_y = load_svhn_small(data_dir, num_val=5000) 126 | eval_x = valid_x 127 | eval_y = valid_y 128 | else: 129 | def rescale(mat): 130 | return np.transpose(np.cast[theano.config.floatX]((-127.5 + mat)/127.5),(3,2,0,1)) 131 | train_x, train_y = svhn_data.load('/home/chongxuan/mfs/data/svhn/','train') 132 | eval_x, eval_y = svhn_data.load('/home/chongxuan/mfs/data/svhn/','test') 133 | 134 | train_y = np.int32(train_y) 135 | eval_y = np.int32(eval_y) 136 | train_x = rescale(train_x) 137 | eval_x = rescale(eval_x) 138 | x_unlabelled = train_x.copy() 139 | 140 | print train_x.shape, eval_x.shape 141 | 142 | rng_data = np.random.RandomState(ssl_data_seed) 143 | inds = rng_data.permutation(train_x.shape[0]) 144 | train_x = train_x[inds] 145 | train_y = train_y[inds] 146 | x_labelled = [] 147 | y_labelled = [] 148 | for j in range(num_classes): 149 | x_labelled.append(train_x[train_y==j][:num_labelled/num_classes]) 150 | y_labelled.append(train_y[train_y==j][:num_labelled/num_classes]) 151 | x_labelled = np.concatenate(x_labelled, axis=0) 152 | y_labelled = np.concatenate(y_labelled, axis=0) 153 | del train_x 154 | 155 | if True: 156 | print 'Size of training data', x_labelled.shape[0], x_unlabelled.shape[0] 157 | y_order = np.argsort(y_labelled) 158 | _x_mean = x_labelled[y_order] 159 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, tile_shape=(num_classes, num_labelled/num_classes), colorImg=colorImg, scale=generation_scale, save_path=os.path.join(outfolder, 'x_l_'+str(ssl_data_seed)+'_triple-gan.png')) 160 | 161 | pretrain_batches_train_uc = x_unlabelled.shape[0] / pre_batch_size_uc 162 | pretrain_batches_train_lc = x_labelled.shape[0] / pre_batch_size_lc 163 | n_batches_train_u_c = x_unlabelled.shape[0] / batch_size_u_c 164 | n_batches_train_l_c = x_labelled.shape[0] / batch_size_l_c 165 | n_batches_train_u_d = x_unlabelled.shape[0] / batch_size_u_d 166 | n_batches_train_l_d = x_labelled.shape[0] / batch_size_l_d 167 | n_batches_train_g = x_unlabelled.shape[0] / batch_size_g 168 | n_batches_eval = eval_x.shape[0] / batch_size_eval 169 | 170 | 171 | ''' 172 | models 173 | ''' 174 | # symbols 175 | sym_z_image = T.tile(theano_rng.uniform((z_generated, n_z)), (num_classes, 1)) 176 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z)) 177 | sym_x_u = T.tensor4() 178 | sym_x_u_d = T.tensor4() 179 | sym_x_u_g = T.tensor4() 180 | sym_x_l = T.tensor4() 181 | sym_y = T.ivector() 182 | sym_y_g = T.ivector() 183 | sym_x_eval = T.tensor4() 184 | sym_lr = T.scalar() 185 | sym_alpha_cla_g = T.scalar() 186 | sym_alpha_unlabel_entropy = 
T.scalar() 187 | sym_alpha_unlabel_average = T.scalar() 188 | 189 | shared_unlabel = theano.shared(x_unlabelled, borrow=True) 190 | slice_x_u_g = T.ivector() 191 | slice_x_u_d = T.ivector() 192 | slice_x_u_c = T.ivector() 193 | 194 | # classifier x2y: p_c(x, y) = p(x) p_c(y | x) 195 | cla_in_x = ll.InputLayer(shape=(None, in_channels) + dim_input) 196 | cla_layers = [cla_in_x] 197 | cla_layers.append(ll.DropoutLayer(cla_layers[-1], p=0.2, name='cla-00')) 198 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 128, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-02'), name='cla-03')) 199 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 128, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-11'), name='cla-12')) 200 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 128, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-21'), name='cla-22')) 201 | cla_layers.append(dnn.MaxPool2DDNNLayer(cla_layers[-1], pool_size=(2, 2))) 202 | cla_layers.append(ll.DropoutLayer(cla_layers[-1], p=0.5, name='cla-23')) 203 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 256, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-31'), name='cla-32')) 204 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 256, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-41'), name='cla-42')) 205 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 256, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-51'), name='cla-52')) 206 | cla_layers.append(dnn.MaxPool2DDNNLayer(cla_layers[-1], pool_size=(2, 2))) 207 | cla_layers.append(ll.DropoutLayer(cla_layers[-1], p=0.5, name='cla-53')) 208 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 512, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-61'), name='cla-62')) 209 | cla_layers.append(ll.batch_norm(ll.NINLayer(cla_layers[-1], num_units=256, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-71'), name='cla-72')) 210 | cla_layers.append(ll.batch_norm(ll.NINLayer(cla_layers[-1], num_units=128, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-81'), name='cla-82')) 211 | cla_layers.append(ll.GlobalPoolLayer(cla_layers[-1], name='cla-83')) 212 | cla_layers.append(ll.batch_norm(ll.DenseLayer(cla_layers[-1], num_units=num_classes, W=Normal(0.05), nonlinearity=ln.softmax, name='cla-91'), name='cla-92')) 213 | 214 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z) 215 | gen_in_z = ll.InputLayer(shape=(None, n_z)) 216 | gen_in_y = ll.InputLayer(shape=(None,)) 217 | gen_layers = [gen_in_z] 218 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-00')) 219 | gen_layers.append(nn.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=4*4*512, W=Normal(0.05), nonlinearity=nn.relu, name='gen-01'), g=None, name='gen-02')) 220 | gen_layers.append(ll.ReshapeLayer(gen_layers[-1], (-1,512,4,4), name='gen-03')) 221 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-10')) 222 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,256,8,8), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-11'), g=None, name='gen-12')) # 4 -> 8 223 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-20')) 224 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,128,16,16), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-21'), 
g=None, name='gen-22')) # 8 -> 16 225 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-30')) 226 | gen_layers.append(nn.weight_norm(nn.Deconv2DLayer(gen_layers[-1], (None,3,32,32), (5,5), W=Normal(0.05), nonlinearity=gen_final_non, name='gen-31'), train_g=True, init_stdv=0.1, name='gen-32')) # 16 -> 32 227 | 228 | # discriminator xy2p: test a pair of input comes from p(x, y) instead of p_c or p_g 229 | dis_in_x = ll.InputLayer(shape=(None, in_channels) + dim_input) 230 | dis_in_y = ll.InputLayer(shape=(None,)) 231 | dis_layers = [dis_in_x] 232 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-00')) 233 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-01')) 234 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-02'), name='dis-03')) 235 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-20')) 236 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-21'), name='dis-22')) 237 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-23')) 238 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-30')) 239 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-31'), name='dis-32')) 240 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-40')) 241 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-41'), name='dis-42')) 242 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-43')) 243 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-50')) 244 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-51'), name='dis-52')) 245 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-60')) 246 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-61'), name='dis-62')) 247 | dis_layers.append(ll.GlobalPoolLayer(dis_layers[-1], name='dis-63')) 248 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-70')) 249 | dis_layers.append(nn.weight_norm(ll.DenseLayer(dis_layers[-1], num_units=1, W=Normal(0.05), nonlinearity=ln.sigmoid, name='dis-71'), train_g=True, init_stdv=0.1, name='dis-72')) 250 | 251 | 252 | ''' 253 | objectives 254 | ''' 255 | # outputs 256 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False) 257 | 258 | cla_out_y_l = ll.get_output(cla_layers[-1], sym_x_l, deterministic=False) 259 | cla_out_y_eval = ll.get_output(cla_layers[-1], sym_x_eval, deterministic=True) 260 | cla_out_y = ll.get_output(cla_layers[-1], sym_x_u, deterministic=False) 261 | cla_out_y_d = ll.get_output(cla_layers[-1], {cla_in_x:sym_x_u_d}, deterministic=False) 262 | cla_out_y_d_hard = cla_out_y_d.argmax(axis=1) 263 | cla_out_y_g = ll.get_output(cla_layers[-1], {cla_in_x:gen_out_x}, deterministic=False) 264 | 265 | dis_out_p = ll.get_output(dis_layers[-1], {dis_in_x:T.concatenate([sym_x_l,sym_x_u_d], 
axis=0),dis_in_y:T.concatenate([sym_y,cla_out_y_d_hard], axis=0)}, deterministic=False) 266 | dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x:gen_out_x,dis_in_y:sym_y_g}, deterministic=False) 267 | 268 | if objective_flag == 'integrate': 269 | # integrate 270 | dis_out_p_c = ll.get_output(dis_layers[-1], 271 | {dis_in_x:T.repeat(sym_x_u, num_classes, axis=0), 272 | dis_in_y:np.tile(np.arange(num_classes), batch_size_u_c)}, 273 | deterministic=False) 274 | elif objective_flag == 'argmax': 275 | # argmax approximation 276 | cla_out_y_hard = cla_out_y.argmax(axis=1) 277 | dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x:sym_x_u,dis_in_y:cla_out_y_hard}, deterministic=False) 278 | else: 279 | raise Exception('Unknown objective flags') 280 | 281 | image = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_image}, deterministic=False) # for generation 282 | 283 | accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y)) # for evaluation 284 | accurracy_eval = accurracy_eval.mean() 285 | 286 | # costs 287 | bce = lasagne.objectives.binary_crossentropy 288 | 289 | dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean() # D distincts p 290 | dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean() # D distincts p_g 291 | gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean() # G fools D 292 | 293 | weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted({cla_layers[-1]:1}, lasagne.regularization.l2) # weight decay 294 | 295 | dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape)) # D distincts p_c 296 | cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape)) # C fools D 297 | 298 | if objective_flag == 'integrate': 299 | # integrate 300 | weight_loss_c = T.reshape(cla_cost_p_c, (-1, num_classes)) * cla_out_y 301 | cla_cost_p_c = T.sum(weight_loss_c, axis=1).mean() 302 | weight_loss_d = T.reshape(dis_cost_p_c, (-1, num_classes)) * cla_out_y 303 | dis_cost_p_c = T.sum(weight_loss_d, axis=1).mean() 304 | elif objective_flag == 'argmax': 305 | # argmax approximation 306 | p = cla_out_y.max(axis=1) 307 | cla_cost_p_c = (cla_cost_p_c*p).mean() 308 | dis_cost_p_c = dis_cost_p_c.mean() 309 | 310 | cla_cost_cla = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=sym_alpha_unlabel_entropy, alpha_average=sym_alpha_unlabel_average, alpha_decay=alpha_decay) # classification loss 311 | 312 | pretrain_cla_loss = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=pre_alpha_unlabeled_entropy, alpha_average=pre_alpha_average, alpha_decay=alpha_decay) # classification loss 313 | pretrain_cost = pretrain_cla_loss 314 | 315 | cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g) 316 | 317 | dis_cost = dis_cost_p + .5*dis_cost_p_g + .5*dis_cost_p_c 318 | gen_cost = .5*gen_cost_p_g 319 | # flag 320 | cla_cost = alpha_cla_adversarial*.5*cla_cost_p_c + cla_cost_cla + sym_alpha_cla_g*cla_cost_cla_g 321 | # fast 322 | cla_cost_fast = alpha_cla_adversarial*.5*cla_cost_p_c + cla_cost_cla 323 | 324 | dis_cost_list=[dis_cost, dis_cost_p, .5*dis_cost_p_g, .5*dis_cost_p_c] 325 | gen_cost_list=[gen_cost,] 326 | # flag 327 | cla_cost_list=[cla_cost, alpha_cla_adversarial*.5*cla_cost_p_c, cla_cost_cla, 
sym_alpha_cla_g*cla_cost_cla_g] 328 | # fast 329 | cla_cost_list_fast=[cla_cost_fast, alpha_cla_adversarial*.5*cla_cost_p_c, cla_cost_cla] 330 | 331 | # updates of D 332 | dis_params = ll.get_all_params(dis_layers, trainable=True) 333 | dis_grads = T.grad(dis_cost, dis_params) 334 | dis_updates = lasagne.updates.adam(dis_grads, dis_params, beta1=b1_d, learning_rate=sym_lr) 335 | 336 | # updates of G 337 | gen_params = ll.get_all_params(gen_layers, trainable=True) 338 | gen_grads = T.grad(gen_cost, gen_params) 339 | gen_updates = lasagne.updates.adam(gen_grads, gen_params, beta1=b1_g, learning_rate=sym_lr) 340 | 341 | # updates of C 342 | cla_params = ll.get_all_params(cla_layers, trainable=True) 343 | cla_grads = T.grad(cla_cost, cla_params) 344 | cla_updates_ = lasagne.updates.adam(cla_grads, cla_params, beta1=b1_c, learning_rate=sym_lr) 345 | 346 | # fast updates of C 347 | cla_params = ll.get_all_params(cla_layers, trainable=True) 348 | cla_grads_fast = T.grad(cla_cost_fast, cla_params) 349 | cla_updates_fast_ = lasagne.updates.adam(cla_grads_fast, cla_params, beta1=b1_c, learning_rate=sym_lr) 350 | 351 | pre_cla_grad = T.grad(pretrain_cost, cla_params) 352 | pretrain_updates_ = lasagne.updates.adam(pre_cla_grad, cla_params, beta1=0.9, beta2=0.999, 353 | epsilon=1e-8, learning_rate=pre_lr) 354 | 355 | ######## avg 356 | avg_params = lasagne.layers.get_all_params(cla_layers) 357 | cla_param_avg=[] 358 | for param in avg_params: 359 | value = param.get_value(borrow=True) 360 | cla_param_avg.append(theano.shared(np.zeros(value.shape, dtype=value.dtype), 361 | broadcastable=param.broadcastable, 362 | name=param.name)) 363 | cla_avg_updates = [(a,a + 0.01*(p-a)) for p,a in zip(avg_params,cla_param_avg)] 364 | cla_avg_givens = [(p,a) for p,a in zip(avg_params, cla_param_avg)] 365 | cla_updates = cla_updates_.items() + cla_avg_updates 366 | cla_updates_fast = cla_updates_fast_.items() + cla_avg_updates 367 | pretrain_updates = pretrain_updates_.items() + cla_avg_updates 368 | 369 | # functions 370 | train_batch_dis = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, 371 | slice_x_u_c, slice_x_u_d, sym_lr], 372 | outputs=dis_cost_list, updates=dis_updates, 373 | givens={sym_x_u: shared_unlabel[slice_x_u_c], 374 | sym_x_u_d: shared_unlabel[slice_x_u_d]}) 375 | train_batch_gen = theano.function(inputs=[sym_y_g, sym_lr], 376 | outputs=gen_cost_list, updates=gen_updates) 377 | train_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, slice_x_u_c, sym_alpha_cla_g, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average], 378 | outputs=cla_cost_list, updates=cla_updates, 379 | givens={sym_x_u: shared_unlabel[slice_x_u_c]}) 380 | # fast 381 | train_batch_cla_fast = theano.function(inputs=[sym_x_l, sym_y, slice_x_u_c, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average], 382 | outputs=cla_cost_list_fast, updates=cla_updates_fast, 383 | givens={sym_x_u: shared_unlabel[slice_x_u_c]}) 384 | 385 | sym_index = T.iscalar() 386 | bslice = slice(sym_index*pre_batch_size_uc, (sym_index+1)*pre_batch_size_uc) 387 | pretrain_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_index], 388 | outputs=[pretrain_cost], updates=pretrain_updates, 389 | givens={sym_x_u: shared_unlabel[bslice]}) 390 | generate = theano.function(inputs=[sym_y_g], outputs=image) 391 | # avg 392 | evaluate = theano.function(inputs=[sym_x_eval, sym_y], outputs=[accurracy_eval], givens=cla_avg_givens) 393 | 394 | ''' 395 | Load pretrained model 396 | ''' 397 | if 'oldmodel' in cfg: 398 | from utils.checkpoints import 
load_weights 399 | load_weights(cfg['oldmodel'], dis_layers+cla_layers+gen_layers) 400 | for (p, a) in zip(ll.get_all_params(cla_layers), avg_params): 401 | a.set_value(p.get_value()) 402 | 403 | 404 | ''' 405 | Pretrain C 406 | ''' 407 | for epoch in range(1, 1+pre_num_epoch): 408 | # randomly permute data and labels 409 | p_l = rng.permutation(x_labelled.shape[0]) 410 | x_labelled = x_labelled[p_l] 411 | y_labelled = y_labelled[p_l] 412 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32') 413 | shared_unlabel.set_value(x_unlabelled[p_u]) 414 | 415 | for i in range(pretrain_batches_train_uc): 416 | i_c = i % pretrain_batches_train_lc 417 | from_l_c = i_c*pre_batch_size_lc 418 | to_l_c = (i_c+1)*pre_batch_size_lc 419 | 420 | pretrain_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], i) 421 | accurracy=[] 422 | for i in range(n_batches_eval): 423 | accurracy_batch = evaluate(eval_x[i*batch_size_eval:(i+1)*batch_size_eval], eval_y[i*batch_size_eval:(i+1)*batch_size_eval]) 424 | accurracy += accurracy_batch 425 | accurracy=np.mean(accurracy) 426 | print str(epoch) + ':Pretrain accuracy: ' + str(1-accurracy) 427 | 428 | ''' 429 | train and evaluate 430 | ''' 431 | print 'Start training' 432 | for epoch in range(1, 1+num_epochs): 433 | start = time.time() 434 | 435 | # randomly permute data and labels 436 | p_l = rng.permutation(x_labelled.shape[0]) 437 | x_labelled = x_labelled[p_l] 438 | y_labelled = y_labelled[p_l] 439 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32') 440 | p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32') 441 | p_u_g = rng.permutation(x_unlabelled.shape[0]).astype('int32') 442 | 443 | dl = [0.] * len(dis_cost_list) 444 | gl = [0.] * len(gen_cost_list) 445 | cl = [0.] * len(cla_cost_list) 446 | 447 | # fast 448 | dis_t=0. 449 | gen_t=0. 450 | cla_t=0. 
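    # Each batch below alternates the three players' updates:
    #   train_batch_dis        -- discriminator step on labelled pairs (padded with
    #                             pseudo-labelled unlabelled data) vs. generated and
    #                             classifier-labelled pairs,
    #   train_batch_gen        -- generator step, trying to fool the discriminator,
    #   train_batch_cla(_fast) -- classifier step; before epoch_cla_g the cross-entropy
    #                             on generated samples is skipped and a 0 is appended
    #                             so cl keeps the same length as cla_cost_list.
    # dis_t / gen_t / cla_t only accumulate per-component wall-clock time for the log.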
451 | for i in range(n_batches_train_u_c): 452 | from_u_c = i*batch_size_u_c 453 | to_u_c = (i+1)*batch_size_u_c 454 | i_c = i % n_batches_train_l_c 455 | from_l_c = i_c*batch_size_l_c 456 | to_l_c = (i_c+1)*batch_size_l_c 457 | i_d = i % n_batches_train_l_d 458 | from_l_d = i_d*batch_size_l_d 459 | to_l_d = (i_d+1)*batch_size_l_d 460 | i_d_ = i % n_batches_train_u_d 461 | from_u_d = i_d_*batch_size_u_d 462 | to_u_d = (i_d_+1)*batch_size_u_d 463 | 464 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes)) 465 | 466 | tmp = time.time() 467 | dl_b = train_batch_dis(x_labelled[from_l_d:to_l_d], y_labelled[from_l_d:to_l_d], sample_y, p_u[from_u_c:to_u_c], p_u_d[from_u_d:to_u_d], lr) 468 | for j in xrange(len(dl)): 469 | dl[j] += dl_b[j] 470 | 471 | tmp1 = time.time() 472 | 473 | gl_b = train_batch_gen(sample_y, lr) 474 | for j in xrange(len(gl)): 475 | gl[j] += gl_b[j] 476 | 477 | tmp2 = time.time() 478 | 479 | # fast 480 | if epoch < epoch_cla_g: 481 | cl_b = train_batch_cla_fast(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], p_u[from_u_c:to_u_c], lr, alpha_unlabeled_entropy, alpha_average) 482 | cl_b += [0,] 483 | else: 484 | cl_b = train_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], sample_y, p_u[from_u_c:to_u_c], alpha_cla_g, lr, alpha_unlabeled_entropy, alpha_average) 485 | 486 | for j in xrange(len(cl)): 487 | cl[j] += cl_b[j] 488 | 489 | tmp3 = time.time() 490 | dis_t+=(tmp1-tmp) 491 | gen_t+=(tmp2-tmp1) 492 | cla_t+=(tmp3-tmp2) 493 | 494 | print 'dis:', dis_t, 'gen:', gen_t, 'cla:', cla_t, 'total', dis_t+gen_t+cla_t 495 | 496 | for i in xrange(len(dl)): 497 | dl[i] /= n_batches_train_u_c 498 | for i in xrange(len(gl)): 499 | gl[i] /= n_batches_train_u_c 500 | for i in xrange(len(cl)): 501 | cl[i] /= n_batches_train_u_c 502 | 503 | if (epoch >= anneal_lr_epoch) and (epoch % anneal_lr_every_epoch == 0): 504 | lr = lr*anneal_lr_factor 505 | 506 | t = time.time() - start 507 | 508 | line = "*Epoch=%d Time=%.2f LR=%.5f\n" %(epoch, t, lr) + "DisLosses: " + str(dl)+"\nGenLosses: "+str(gl)+"\nClaLosses: "+str(cl) 509 | 510 | print line 511 | with open(logfile,'a') as f: 512 | f.write(line + "\n") 513 | 514 | # random generation for visualization 515 | if epoch % vis_epoch == 0: 516 | import utils.paramgraphics as paramgraphics 517 | tail = '-'+str(epoch)+'.png' 518 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes)) 519 | x_gen = generate(ran_y) 520 | x_gen = x_gen.reshape((z_generated*num_classes,-1)) 521 | image = paramgraphics.mat_to_img(x_gen.T, dim_input, colorImg=colorImg, scale=generation_scale, save_path=os.path.join(sample_path, 'sample'+tail)) 522 | 523 | if epoch % eval_epoch == 0: 524 | accurracy=[] 525 | for i in range(n_batches_eval): 526 | accurracy_batch = evaluate(eval_x[i*batch_size_eval:(i+1)*batch_size_eval], eval_y[i*batch_size_eval:(i+1)*batch_size_eval]) 527 | accurracy += accurracy_batch 528 | accurracy=np.mean(accurracy) 529 | print ('ErrorEval=%.5f\n' % (1-accurracy,)) 530 | with open(logfile,'a') as f: 531 | f.write(('ErrorEval=%.5f\n\n' % (1-accurracy,))) 532 | 533 | if epoch % 200 == 0 or (epoch == epoch_cla_g - 1): 534 | from utils.checkpoints import save_weights 535 | params = ll.get_all_params(dis_layers+cla_layers+gen_layers) 536 | save_weights(os.path.join(outfolder, 'model_epoch' + str(epoch) + '.npy'), params, None) 537 | save_weights(os.path.join(outfolder, 'average'+ str(epoch) +'.npy'), cla_param_avg, None) 538 | 
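For reference, the `argmax` objective used above collapses the expectation over y ~ p_c(y|x_u) to the single most likely label: the unlabelled pair shown to D is (x_u, argmax_y p_c(y|x_u)), and the classifier's adversarial term weights each sample's loss by the corresponding max probability (`cla_cost_p_c = (cla_cost_p_c*p).mean()`). A minimal NumPy sketch of just that weighting, with made-up classifier and discriminator outputs:

```python
import numpy as np

rng = np.random.RandomState(0)
num_classes, batch = 10, 4

# Hypothetical classifier softmax outputs p_c(y | x_u) for unlabelled inputs.
cla_out_y = rng.dirichlet(np.ones(num_classes), size=batch).astype('float32')

# Argmax approximation: the hard label paired with x_u for the discriminator.
y_hard = cla_out_y.argmax(axis=1)

# Hypothetical discriminator outputs D(x_u, y_hard) in (0, 1).
d_out = rng.uniform(0.05, 0.95, size=batch).astype('float32')

# Classifier side: binary cross-entropy against "real", weighted by max_y p_c(y | x_u).
bce_real = -np.log(d_out)
p_max = cla_out_y.max(axis=1)
cla_cost_p_c = np.mean(bce_real * p_max)

# Discriminator side: the same pairs are pushed towards "fake", unweighted.
dis_cost_p_c = np.mean(-np.log(1.0 - d_out))

print(float(cla_cost_p_c))
```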
-------------------------------------------------------------------------------- /zca_bn.py: -------------------------------------------------------------------------------- 1 | # ZCA and MeanOnlyBNLayer implementations copied from 2 | # https://github.com/TimSalimans/weight_norm/blob/master/nn.py 3 | # 4 | # Modifications made to MeanOnlyBNLayer: 5 | # - Added configurable momentum. 6 | # - Added 'modify_incoming' flag for weight matrix sharing (not used in this project). 7 | # - Sums and means use float32 datatype. 8 | 9 | import numpy as np 10 | import theano as th 11 | import theano.tensor as T 12 | from scipy import linalg 13 | import lasagne 14 | 15 | class ZCA(object): 16 | def __init__(self, regularization=1e-5, x=None): 17 | self.regularization = regularization 18 | if x is not None: 19 | self.fit(x) 20 | 21 | def fit(self, x): 22 | s = x.shape 23 | x = x.copy().reshape((s[0],np.prod(s[1:]))) 24 | m = np.mean(x, axis=0) 25 | x -= m 26 | sigma = np.dot(x.T,x) / x.shape[0] 27 | U, S, V = linalg.svd(sigma) 28 | tmp = np.dot(U, np.diag(1./np.sqrt(S+self.regularization))) 29 | tmp2 = np.dot(U, np.diag(np.sqrt(S+self.regularization))) 30 | self.ZCA_mat = th.shared(np.dot(tmp, U.T).astype(th.config.floatX)) 31 | self.inv_ZCA_mat = th.shared(np.dot(tmp2, U.T).astype(th.config.floatX)) 32 | self.mean = th.shared(m.astype(th.config.floatX)) 33 | 34 | def apply(self, x): 35 | s = x.shape 36 | if isinstance(x, np.ndarray): 37 | return np.dot(x.reshape((s[0],np.prod(s[1:]))) - self.mean.get_value(), self.ZCA_mat.get_value()).reshape(s) 38 | elif isinstance(x, T.TensorVariable): 39 | return T.dot(x.flatten(2) - self.mean.dimshuffle('x',0), self.ZCA_mat).reshape(s) 40 | else: 41 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables") 42 | 43 | def invert(self, x): 44 | s = x.shape 45 | if isinstance(x, np.ndarray): 46 | return (np.dot(x.reshape((s[0],np.prod(s[1:]))), self.inv_ZCA_mat.get_value()) + self.mean.get_value()).reshape(s) 47 | elif isinstance(x, T.TensorVariable): 48 | return (T.dot(x.flatten(2), self.inv_ZCA_mat) + self.mean.dimshuffle('x',0)).reshape(s) 49 | else: 50 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables") 51 | 52 | # T.nnet.relu has some issues with very large inputs, this is more stable 53 | def relu(x): 54 | return T.maximum(x, 0) 55 | 56 | class MeanOnlyBNLayer(lasagne.layers.Layer): 57 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), 58 | W=lasagne.init.Normal(0.05), nonlinearity=relu, modify_incoming=True, momentum=0.9, **kwargs): 59 | super(MeanOnlyBNLayer, self).__init__(incoming, **kwargs) 60 | self.nonlinearity = nonlinearity 61 | self.momentum = momentum 62 | k = self.input_shape[1] 63 | if b is not None: 64 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 65 | if g is not None: 66 | self.g = self.add_param(g, (k,), name="g") 67 | self.avg_batch_mean = self.add_param(lasagne.init.Constant(0.), (k,), name="avg_batch_mean", regularizable=False, trainable=False) 68 | if len(self.input_shape)==4: 69 | self.axes_to_sum = (0,2,3) 70 | self.dimshuffle_args = ['x',0,'x','x'] 71 | else: 72 | self.axes_to_sum = 0 73 | self.dimshuffle_args = ['x',0] 74 | 75 | # scale weights in layer below 76 | incoming.W_param = incoming.W 77 | if modify_incoming: 78 | incoming.W_param.set_value(W.sample(incoming.W_param.get_value().shape)) 79 | if incoming.W_param.ndim==4: 80 | W_axes_to_sum = (1,2,3) 81 | W_dimshuffle_args = 
[0,'x','x','x'] 82 | else: 83 | W_axes_to_sum = 0 84 | W_dimshuffle_args = ['x',0] 85 | if g is not None: 86 | incoming.W = incoming.W_param * (self.g/T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,dtype=th.config.floatX, acc_dtype=th.config.floatX))).dimshuffle(*W_dimshuffle_args) 87 | else: 88 | incoming.W = incoming.W_param / T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,keepdims=True,dtype=th.config.floatX,acc_dtype=th.config.floatX)) 89 | 90 | def get_output_for(self, input, deterministic=False, init=False, **kwargs): 91 | if deterministic: 92 | activation = input - self.avg_batch_mean.dimshuffle(*self.dimshuffle_args) 93 | else: 94 | m = T.mean(input,axis=self.axes_to_sum,dtype=th.config.floatX,acc_dtype=th.config.floatX) 95 | activation = input - m.dimshuffle(*self.dimshuffle_args) 96 | self.bn_updates = [(self.avg_batch_mean, self.momentum*self.avg_batch_mean + (1.0-self.momentum)*m)] 97 | if init: 98 | stdv = T.sqrt(T.mean(T.square(activation),axis=self.axes_to_sum,dtype=th.config.floatX,acc_dtype=th.config.floatX)) 99 | activation /= stdv.dimshuffle(*self.dimshuffle_args) 100 | self.init_updates = [(self.g, self.g/stdv)] 101 | if hasattr(self, 'b'): 102 | activation += self.b.dimshuffle(*self.dimshuffle_args) 103 | 104 | return self.nonlinearity(activation) 105 | 106 | def mean_only_bn(layer, **kwargs): 107 | nonlinearity = getattr(layer, 'nonlinearity', None) 108 | if nonlinearity is not None: 109 | layer.nonlinearity = lasagne.nonlinearities.identity 110 | if hasattr(layer, 'b') and layer.b is not None: 111 | del layer.params[layer.b] 112 | layer.b = None 113 | return MeanOnlyBNLayer(layer, name=layer.name+'_n', nonlinearity=nonlinearity, **kwargs) 114 | --------------------------------------------------------------------------------
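As a quick usage sketch for the `ZCA` class above (assuming Theano is installed and this module is importable as `zca_bn`; the data shape is made up), fitting, whitening and inverting a batch of image-shaped arrays:

```python
import numpy as np
from zca_bn import ZCA

# Fake CIFAR-shaped batch; ZCA.fit flattens everything after the first axis.
x = np.random.RandomState(0).rand(64, 3, 32, 32).astype('float32')

whitener = ZCA(regularization=1e-5, x=x)   # fit() is called by the constructor
x_white = whitener.apply(x)                # whitened data, same shape as x
x_back = whitener.invert(x_white)          # (approximately) undoes the whitening

print(np.abs(x_back - x).max())            # round-trip reconstruction error
```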