├── README.md
├── cifar10_data.py
├── components
│   ├── __init__.py
│   ├── objectives.py
│   └── shortcuts.py
├── generate_given_models.py
├── images
│   ├── class-0.png
│   ├── class-1.png
│   ├── class-2.png
│   ├── class-7.png
│   ├── mnist_fm.png
│   ├── mnist_random.png
│   ├── svhn_data.png
│   ├── svhn_linear.png
│   └── svhn_share_z.png
├── layers
│   ├── __init__.py
│   ├── deconv.py
│   ├── merge.py
│   └── sample.py
├── nn.py
├── svhn_data.py
├── utils
│   ├── __init__.py
│   ├── checkpoints.py
│   ├── create_ssl_data.py
│   ├── others.py
│   └── paramgraphics.py
├── x2y_yz2x_xy2p_ssl_cifar10.py
├── x2y_yz2x_xy2p_ssl_mnist.py
├── x2y_yz2x_xy2p_ssl_svhn.py
└── zca_bn.py
/README.md:
--------------------------------------------------------------------------------
1 | # Triple Generative Adversarial Nets (Triple-GAN)
2 | ## [Chongxuan Li](https://github.com/zhenxuan00), [Kun Xu](https://github.com/taufikxu), Jun Zhu and Bo Zhang
3 |
4 | Code for reproducing most of the results in the [paper](https://arxiv.org/abs/1703.02291). Triple-GAN is a unified GAN model for classification and class-conditional generation in semi-supervised learning.
5 |
6 | Warning: the code is still under development.
7 |
8 |
9 | ## Triple-GAN-V2 and code in PyTorch!
10 |
11 | We propose Triple-GAN-V2, built upon a mean-teacher classifier and a projection discriminator with spectral normalization, and implement Triple-GAN in PyTorch. See the source code at https://github.com/taufikxu/Triple-GAN
12 |
13 |
14 | ## Environment settings and libs we used in our experiments
15 |
16 | This project is tested under the following environment settings.
17 | - OS: Ubuntu 16.04.3
18 | - GPU: GeForce 1080 Ti or Titan X (Pascal or Maxwell)
19 | - CUDA: 8.0, cuDNN: v5.1 or v7.03
20 | - Python: 2.7.14 (set up with Miniconda2)
21 | - Theano: 0.9.0.dev-c697eeab84e5b8a74908da654b66ec9eca4f1291
22 | - Lasagne: 0.2.dev1
23 | - Parmesan: 0.1.dev1
24 |
25 | > Python
26 | > Numpy
27 | > Scipy
28 | > [Theano](https://github.com/Theano/Theano)
29 | > [Lasagne](https://github.com/Lasagne/Lasagne) (version 0.2.dev1)
30 | > [Parmesan](https://github.com/casperkaae/parmesan)
31 |
32 | We thank the authors of these libs. We also thank the authors of [Improved-GAN](https://github.com/openai/improved-gan) and [Temporal Ensemble](https://github.com/smlaine2/tempens) for providing their code. Our code is largely adapted from their repositories.
33 |
34 | ## Results
35 |
36 | Triple-GAN achieves excellent classification results on the MNIST, SVHN and CIFAR10 datasets; see the [paper](https://arxiv.org/abs/1703.02291) for a comparison with the previous state of the art. Generated images are shown below:
37 |
38 | ### Comparing Triple-GAN (right) with GAN trained with [feature matching](https://arxiv.org/abs/1606.03498) (left)
39 |
40 | ![mnist_fm](images/mnist_fm.png) ![mnist_random](images/mnist_random.png)
41 | ### Generating images in four specific classes (airplane, automobile, bird, horse)
42 |
43 | ![class-0](images/class-0.png) ![class-1](images/class-1.png) ![class-2](images/class-2.png) ![class-7](images/class-7.png)
44 |
45 | ### Disentangling styles from classes (left: data, right: Triple-GAN)
46 |
47 | ![svhn_data](images/svhn_data.png) ![svhn_share_z](images/svhn_share_z.png)
48 | ### Class-conditional linear interpolation on latent space
49 |
50 | ![svhn_linear](images/svhn_linear.png)
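51 |
52 | ## Generating images from a pretrained model
53 |
54 | A minimal sketch (the checkpoint path is hypothetical); `-oldmodel` and `-dataset` are the flags defined in `generate_given_models.py`:
55 |
56 | ```
57 | python generate_given_models.py -oldmodel <path-to-saved-weights> -dataset cifar10
58 | ```
59 |
60 | This writes class-conditional samples, shared-z samples and latent-space interpolations to a `results-ssl/generate_given_models.<dataset>.<timestamp>` folder.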
--------------------------------------------------------------------------------
/cifar10_data.py:
--------------------------------------------------------------------------------
1 | import cPickle
2 | import os
3 | import sys
4 | import tarfile
5 | from six.moves import urllib
6 | import numpy as np
7 |
8 | def maybe_download_and_extract(data_dir, url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'):
9 | if not os.path.exists(os.path.join(data_dir, 'cifar-10-batches-py')):
10 | if not os.path.exists(data_dir):
11 | os.makedirs(data_dir)
12 | filename = url.split('/')[-1]
13 | filepath = os.path.join(data_dir, filename)
14 | if not os.path.exists(filepath):
15 | def _progress(count, block_size, total_size):
16 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
17 | float(count * block_size) / float(total_size) * 100.0))
18 | sys.stdout.flush()
19 | filepath, _ = urllib.request.urlretrieve(url, filepath, _progress)
20 | print()
21 | statinfo = os.stat(filepath)
22 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
23 | tarfile.open(filepath, 'r:gz').extractall(data_dir)
24 |
25 | def unpickle(file):
26 | fo = open(file, 'rb')
27 | d = cPickle.load(fo)
28 | fo.close()
29 | return {'x': np.cast[np.float32]((-127.5 + d['data'].reshape((10000,3,32,32)))/128.), 'y': np.array(d['labels']).astype(np.int32)}
30 |
31 | def load(data_dir, subset='train'):
32 | maybe_download_and_extract(data_dir)
33 | if subset=='train':
34 | train_data = [unpickle(os.path.join(data_dir,'cifar-10-batches-py/data_batch_' + str(i))) for i in range(1,6)]
35 | trainx = np.concatenate([d['x'] for d in train_data],axis=0)
36 | trainy = np.concatenate([d['y'] for d in train_data],axis=0)
37 | return trainx, trainy
38 | elif subset=='test':
39 | test_data = unpickle(os.path.join(data_dir,'cifar-10-batches-py/test_batch'))
40 | testx = test_data['x']
41 | testy = test_data['y']
42 | return testx, testy
43 | else:
44 | raise NotImplementedError('subset should be either train or test')
45 |
--------------------------------------------------------------------------------
/components/__init__.py:
--------------------------------------------------------------------------------
1 | from .objectives import *
2 | from .shortcuts import *
--------------------------------------------------------------------------------
/components/objectives.py:
--------------------------------------------------------------------------------
1 | '''
2 | objectives
3 | '''
4 | import numpy as np
5 | import theano.tensor as T
6 | import theano
7 | import lasagne
8 | from theano.tensor.extra_ops import to_one_hot
9 |
10 | def categorical_crossentropy(predictions, targets, epsilon=1e-6):
11 | # avoid overflow
12 | predictions = T.clip(predictions, epsilon, 1-epsilon)
13 | # check shape of targets
14 | num_cls = predictions.shape[1]
15 | if targets.ndim == predictions.ndim - 1:
16 | targets = theano.tensor.extra_ops.to_one_hot(targets, num_cls)
17 | elif targets.ndim != predictions.ndim:
18 | raise TypeError('rank mismatch between targets and predictions')
19 | return lasagne.objectives.categorical_crossentropy(predictions, targets).mean()
20 |
21 | def entropy(predictions):
22 | return categorical_crossentropy(predictions, predictions)
23 |
24 | def negative_entropy_of_mean(predictions):
25 | return -entropy(predictions.mean(axis=0, keepdims=True))
26 |
27 | def categorical_crossentropy_of_mean(predictions):
28 | num_cls = predictions.shape[1]
29 | uniform_targets = T.ones((1, num_cls)) / num_cls
30 | return categorical_crossentropy(predictions.mean(axis=0, keepdims=True), uniform_targets)
31 |
32 | def categorical_crossentropy_ssl_alternative(predictions, targets, num_labelled, weight_decay, alpha_labeled=1., alpha_unlabeled=.3, alpha_average=1e-3, alpha_decay=1e-4):
33 | ce_loss = categorical_crossentropy(predictions[:num_labelled], targets)
34 | en_loss = entropy(predictions[num_labelled:])
35 | av_loss = negative_entropy_of_mean(predictions[num_labelled:])
36 | return alpha_labeled*ce_loss + alpha_unlabeled*en_loss + alpha_average*av_loss + alpha_decay*weight_decay
37 |
38 | def categorical_crossentropy_ssl(predictions, targets, num_labelled, weight_decay, alpha_labeled=1., alpha_unlabeled=.3, alpha_average=1e-3, alpha_decay=1e-4):
39 | ce_loss = categorical_crossentropy(predictions[:num_labelled], targets)
40 | en_loss = entropy(predictions[num_labelled:])
41 | av_loss = categorical_crossentropy_of_mean(predictions[num_labelled:])
42 | return alpha_labeled*ce_loss + alpha_unlabeled*en_loss + alpha_average*av_loss + alpha_decay*weight_decay
43 |
44 | def categorical_crossentropy_ssl_separated(predictions_l, targets, predictions_u, weight_decay, alpha_labeled=1., alpha_unlabeled=.3, alpha_average=1e-3, alpha_decay=1e-4):
45 | ce_loss = categorical_crossentropy(predictions_l, targets)
46 | en_loss = entropy(predictions_u)
47 | av_loss = categorical_crossentropy_of_mean(predictions_u)
48 | return alpha_labeled*ce_loss + alpha_unlabeled*en_loss + alpha_average*av_loss + alpha_decay*weight_decay
49 |
50 | def maximum_mean_discripancy(sample, data, sigma=[2. , 5., 10., 20., 40., 80.]):
51 | sample = sample.flatten(2)
52 | data = data.flatten(2)
53 |
54 | x = T.concatenate([sample, data], axis=0)
55 | xx = T.dot(x, x.T)
56 | x2 = T.sum(x*x, axis=1, keepdims=True)
57 | exponent = xx - .5*x2 - .5*x2.T
58 | s_samples = T.ones([sample.shape[0], 1])*1./sample.shape[0]
59 | s_data = -T.ones([data.shape[0], 1])*1./data.shape[0]
60 | s_all = T.concatenate([s_samples, s_data], axis=0)
61 | s_mat = T.dot(s_all, s_all.T)
62 | mmd_loss = 0.
63 | for s in sigma:
64 | kernel_val = T.exp((1./s) * exponent)
65 | mmd_loss += T.sum(s_mat*kernel_val)
66 | return T.sqrt(mmd_loss)
67 |
68 | def feature_matching(f_sample, f_data, norm='l2'):
69 | if norm == 'l2':
70 | return T.mean(T.square(T.mean(f_sample,axis=0)-T.mean(f_data,axis=0)))
71 | elif norm == 'l1':
72 | return T.mean(abs(T.mean(f_sample,axis=0)-T.mean(f_data,axis=0)))
73 | else:
74 | raise NotImplementedError
75 |
76 | def multiclass_s3vm_loss(predictions_l, targets, predictions_u, weight_decay, alpha_labeled=1., alpha_unlabeled=1., alpha_average=1., alpha_decay=1., delta=1., norm_type=2, form ='mean_class', entropy_term=False):
77 | '''
78 | predictions:
79 | size L x nc
80 | U x nc
81 | targets:
82 | size L x nc
83 |
84 | output:
85 | weighted sum of hinge loss, hat loss, balance constraint and weight decay
86 | '''
87 | num_cls = predictions_l.shape[1]
88 | if targets.ndim == predictions_l.ndim - 1:
89 | targets = theano.tensor.extra_ops.to_one_hot(targets, num_cls)
90 | elif targets.ndim != predictions_l.ndim:
91 | raise TypeError('rank mismatch between targets and predictions')
92 |
93 | hinge_loss = multiclass_hinge_loss_(predictions_l, targets, delta)
94 | hat_loss = multiclass_hat_loss(predictions_u, delta)
95 | regularization = balance_constraint(predictions_l, targets, predictions_u, norm_type, form)
96 | if not entropy_term:
97 | return alpha_labeled*hinge_loss.mean() + alpha_unlabeled*hat_loss.mean() + alpha_average*regularization + alpha_decay*weight_decay
98 | else:
99 |         # For unlabeled data, treating the hat loss as the entropy term derived from a lower bound means it conflicts with the current prediction, which is strange but true: the entropy term pushes the discriminator to predict unlabeled data uniformly, acting as a regularizer.
100 |         # The max-entropy regularization gives a tighter lower bound but hurts semi-supervised performance because it conflicts with the hat loss.
101 | return alpha_labeled*hinge_loss.mean() - alpha_unlabeled*hat_loss.mean() + alpha_average*regularization + alpha_decay*weight_decay
102 |
103 |
104 | def multiclass_hinge_loss_(predictions, targets, delta=1):
105 | return lasagne.objectives.multiclass_hinge_loss(predictions, targets, delta)
106 |
107 | def multiclass_hinge_loss(predictions, targets, weight_decay, alpha_decay=1., delta=1):
108 | return multiclass_hinge_loss_(predictions, targets, delta).mean() + alpha_decay*weight_decay
109 |
110 | def multiclass_hat_loss(predictions, delta=1):
111 | targets = T.argmax(predictions, axis=1)
112 |     return multiclass_hinge_loss_(predictions, targets, delta)
113 |
114 | def balance_constraint(p_l, t_l, p_u, norm_type=2, form='mean_class'):
115 | '''
116 | balance constraint
117 | ------
118 | norm_type: type of norm
119 | l2 or l1
120 | form: form of regularization
121 | mean_class: average mean activation of u and l data should be the same over each class
122 | mean_all: average mean activation of u and l data should be the same over all data
123 | ratio:
124 |
125 | '''
126 | t_u = T.argmax(p_u, axis=1)
127 | num_cls = p_l.shape[1]
128 | t_u = theano.tensor.extra_ops.to_one_hot(t_u, num_cls)
129 | if form == 'mean_class':
130 | res = (p_l*t_l).mean(axis=0) - (p_u*t_u).mean(axis=0)
131 | elif form == 'mean_all':
132 | res = p_l.mean(axis=0) - p_u.mean(axis=0)
133 |     elif form == 'ratio':
134 |         raise NotImplementedError("the 'ratio' form is not implemented")
135 |
136 | # res should be a vector with length number_class
137 | return res.norm(norm_type)
138 |
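139 | # Usage sketch (the tensors below are hypothetical): the SSL classifier loss combines labelled
140 | # cross-entropy, unlabelled entropy, the class-balance term and weight decay, e.g.
141 | #   loss_c = categorical_crossentropy_ssl_separated(p_l, y_l, p_u, w_decay,
142 | #                                                   alpha_labeled=1., alpha_unlabeled=.3)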
--------------------------------------------------------------------------------
/components/shortcuts.py:
--------------------------------------------------------------------------------
1 | '''
2 | shortcuts for composite layers
3 | '''
4 | import numpy as np
5 | import theano.tensor as T
6 | import theano
7 | import lasagne
8 | import sys
9 | sys.path.append("..")
10 | from layers.merge import ConvConcatLayer, MLPConcatLayer
11 |
12 | # convolutional layer
13 | # followed by optional batch normalization, pooling and dropout
14 | def convlayer(l,bn,dr,ps,n_kerns,d_kerns,nonlinearity,pad,stride,W=lasagne.init.GlorotUniform(),b=lasagne.init.Constant(0.), name=None):
15 | l = lasagne.layers.Conv2DLayer(l, num_filters=n_kerns, filter_size=d_kerns, stride=stride, pad=pad, W=W, b=b, nonlinearity=nonlinearity, name=name)
16 | if bn:
17 | l = lasagne.layers.batch_norm(l)
18 | if ps > 1:
19 | l = lasagne.layers.MaxPool2DLayer(l, pool_size=(ps,ps))
20 | if dr > 0:
21 | l = lasagne.layers.DropoutLayer(l, p=dr)
22 | return l
23 |
24 | # mlp layer
25 | # followed by optional batch normalization and dropout
26 | def mlplayer(l,bn,dr,num_units,nonlinearity,name):
27 | l = lasagne.layers.DenseLayer(l,num_units=num_units,nonlinearity=nonlinearity,name="MLP-"+name)
28 | if bn:
29 | l = lasagne.layers.batch_norm(l, name="BN-"+name)
30 | if dr > 0:
31 | l = lasagne.layers.DropoutLayer(l, p=dr, name="Drop-"+name)
32 | return l
33 |
--------------------------------------------------------------------------------
/generate_given_models.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code generates data in various ways given a trained Triple-GAN
3 |
4 | Note: Due to the effect of Batch Normalization, it is better to generate batch_size_g samples (see the train file; 200 for cifar10) distributed equally across classes in each batch.
5 | '''
6 | import gzip, os, cPickle, time, math, argparse, shutil, sys
7 |
8 | import numpy as np
9 | import theano, lasagne
10 | import theano.tensor as T
11 | import lasagne.layers as ll
12 | import lasagne.nonlinearities as ln
13 | from lasagne.layers import dnn
14 | import nn
15 | from lasagne.init import Normal
16 | from theano.sandbox.rng_mrg import MRG_RandomStreams
17 |
18 | from layers.merge import ConvConcatLayer, MLPConcatLayer
19 | from layers.deconv import Deconv2DLayer
20 |
21 | from components.shortcuts import convlayer, mlplayer
22 | from components.objectives import categorical_crossentropy_ssl_separated, maximum_mean_discripancy, categorical_crossentropy, feature_matching
23 | from utils.create_ssl_data import create_ssl_data, create_ssl_data_subset
24 | from utils.others import get_nonlin_list, get_pad_list, bernoullisample, printarray_2D, array2file_2D
25 | import utils.paramgraphics as paramgraphics
26 | from utils.checkpoints import load_weights
27 |
28 | # global
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS)
31 | parser.add_argument("-dataset", type=str, default='svhn')
32 | args = parser.parse_args()
33 |
34 | filename_script=os.path.basename(os.path.realpath(__file__))
35 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0])
36 | outfolder+='.'
37 | outfolder+=args.dataset
38 | outfolder+='.'
39 | outfolder+=str(int(time.time()))
40 | if not os.path.exists(outfolder):
41 | os.makedirs(outfolder)
42 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script))
43 |
44 | # seeds
45 | seed=1234
46 | rng=np.random.RandomState(seed)
47 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15))
48 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15)))
49 |
50 | # G
51 | n_z=100
52 | batch_size_g=200
53 | num_x=50000
54 | # data dependent
55 | if args.dataset == 'svhn' or args.dataset == 'cifar10':
56 | gen_final_non=ln.tanh
57 | num_classes=10
58 | dim_input=(32,32)
59 | in_channels=3
60 | colorImg=True
61 | generation_scale=True
62 | elif args.dataset == 'mnist':
63 | gen_final_non=ln.sigmoid
64 | num_classes=10
65 | dim_input=(28,28)
66 | in_channels=1
67 | colorImg=False
68 | generation_scale=False
69 |
70 | '''
71 | models
72 | '''
73 | # symbols
74 | sym_y_g = T.ivector()
75 | sym_z_input = T.matrix()
76 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z))
77 | sym_z_shared = T.tile(theano_rng.uniform((batch_size_g/num_classes, n_z)), (num_classes, 1))
78 |
79 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z)
80 | gen_in_z = ll.InputLayer(shape=(None, n_z))
81 | gen_in_y = ll.InputLayer(shape=(None,))
82 | gen_layers = [gen_in_z]
83 | if args.dataset == 'svhn' or args.dataset == 'cifar10':
84 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-00'))
85 | gen_layers.append(nn.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=4*4*512, W=Normal(0.05), nonlinearity=nn.relu, name='gen-01'), g=None, name='gen-02'))
86 | gen_layers.append(ll.ReshapeLayer(gen_layers[-1], (-1,512,4,4), name='gen-03'))
87 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-10'))
88 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,256,8,8), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-11'), g=None, name='gen-12')) # 4 -> 8
89 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-20'))
90 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,128,16,16), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-21'), g=None, name='gen-22')) # 8 -> 16
91 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-30'))
92 | gen_layers.append(nn.weight_norm(nn.Deconv2DLayer(gen_layers[-1], (None,3,32,32), (5,5), W=Normal(0.05), nonlinearity=gen_final_non, name='gen-31'), train_g=True, init_stdv=0.1, name='gen-32')) # 16 -> 32
93 | elif args.dataset == 'mnist':
94 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-1'))
95 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-2'), name='gen-3'))
96 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-4'))
97 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-5'), name='gen-6'))
98 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-7'))
99 | gen_layers.append(nn.l2normalize(ll.DenseLayer(gen_layers[-1], num_units=28**2, nonlinearity=gen_final_non, name='gen-8')))
100 |
101 | # outputs
102 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False)
103 | gen_out_x_shared = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_shared}, deterministic=False)
104 | gen_out_x_interpolation = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_input}, deterministic=False)
105 | generate = theano.function(inputs=[sym_y_g], outputs=gen_out_x)
106 | generate_shared = theano.function(inputs=[sym_y_g], outputs=gen_out_x_shared)
107 | generate_interpolation = theano.function(inputs=[sym_y_g, sym_z_input], outputs=gen_out_x_interpolation)
108 |
109 | '''
110 | Load pretrained model
111 | '''
112 | load_weights(args.oldmodel, gen_layers)
113 |
114 | # interpolation on latent space (z) class conditionally
115 | for i in xrange(10):
116 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
117 |     original_z = np.repeat(rng.uniform(size=(num_classes,n_z)), batch_size_g/num_classes, axis=0)
118 |     target_z = np.repeat(rng.uniform(size=(num_classes,n_z)), batch_size_g/num_classes, axis=0)
119 |     alpha = np.tile(np.arange(batch_size_g/num_classes) * 1.0 / (batch_size_g/num_classes-1), num_classes)
120 |     alpha = alpha.reshape(-1,1)
121 |     z = np.float32((1-alpha)*original_z+alpha*target_z)
122 | x_gen_batch = generate_interpolation(sample_y, z)
123 | x_gen_batch = x_gen_batch.reshape((batch_size_g,-1))
124 | image = paramgraphics.mat_to_img(x_gen_batch.T, dim_input, colorImg=colorImg, tile_shape=(num_classes, 2*num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'interpolation-'+str(i)+'.png'))
125 |
126 | # class conditionally generation with shared z and fixed y
127 | for i in xrange(10):
128 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
129 | x_gen_batch = generate_shared(sample_y)
130 | x_gen_batch = x_gen_batch.reshape((batch_size_g,-1))
131 | image = paramgraphics.mat_to_img(x_gen_batch.T, dim_input, colorImg=colorImg, tile_shape=(num_classes, 2*num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'shared-'+str(i)+'.png'))
132 |
133 | # generation with randomly sampled z and y
134 | for i in xrange(10):
135 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
136 | inds = np.random.permutation(batch_size_g)
137 | sample_y = sample_y[inds]
138 | x_gen_batch = generate(sample_y)
139 | x_gen_batch = x_gen_batch.reshape((batch_size_g,-1))
140 | image = paramgraphics.mat_to_img(x_gen_batch.T, dim_input, colorImg=colorImg, tile_shape=(num_classes, 2*num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'random-'+str(i)+'.png'))
141 |
142 | if args.dataset != 'cifar10':
143 | exit()
144 |
145 | # generate a large number of random samples for inception score computation
146 | x_gen = []
147 | # generation for each class
148 | x_classes = []
149 | for i in xrange(num_classes):
150 | x_classes.append([])
151 | for i in xrange(num_x / batch_size_g):
152 | print i
153 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
154 | x_gen_batch = generate(sample_y)
155 | x_gen.append(x_gen_batch)
156 | if i < 5:
157 | for j in xrange(num_classes):
158 | x_classes[j].append(x_gen_batch[j*20:(j+1)*20])
159 | if i == 5:
160 | for ind in xrange(num_classes):
161 | x_classes[ind] = np.concatenate(x_classes[ind], axis=0)
162 | image = paramgraphics.mat_to_img(x_classes[ind].T, dim_input, colorImg=colorImg, tile_shape=(num_classes, num_classes), scale=generation_scale, save_path=os.path.join(outfolder, 'class-'+str(ind)+'.png'))
163 |
164 | x_gen=np.concatenate(x_gen, axis=0)
165 | np.save(os.path.join(outfolder,'inception_score'), x_gen)
166 |
167 |
168 |
169 |
--------------------------------------------------------------------------------
/images/class-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-0.png
--------------------------------------------------------------------------------
/images/class-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-1.png
--------------------------------------------------------------------------------
/images/class-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-2.png
--------------------------------------------------------------------------------
/images/class-7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/class-7.png
--------------------------------------------------------------------------------
/images/mnist_fm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/mnist_fm.png
--------------------------------------------------------------------------------
/images/mnist_random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/mnist_random.png
--------------------------------------------------------------------------------
/images/svhn_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/svhn_data.png
--------------------------------------------------------------------------------
/images/svhn_linear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/svhn_linear.png
--------------------------------------------------------------------------------
/images/svhn_share_z.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenxuan00/triple-gan/9691c007eab21322657036f81ff93ef5634652fa/images/svhn_share_z.png
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .merge import *
2 | from .sample import *
3 | from .deconv import *
--------------------------------------------------------------------------------
/layers/deconv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import theano as th
3 | import theano.tensor as T
4 | import lasagne
5 | from lasagne.layers import dnn
6 |
7 | # code from ImprovedGAN, Salimans and Goodfellow: https://github.com/openai/improved-gan
8 |
9 | class Deconv2DLayer(lasagne.layers.Layer):
10 | def __init__(self, incoming, target_shape, filter_size, stride=(2, 2),
11 | W=lasagne.init.Normal(0.05), b=lasagne.init.Constant(0.), nonlinearity=None, **kwargs):
12 | super(Deconv2DLayer, self).__init__(incoming, **kwargs)
13 | self.target_shape = target_shape
14 | self.nonlinearity = (lasagne.nonlinearities.identity if nonlinearity is None else nonlinearity)
15 | self.filter_size = lasagne.layers.dnn.as_tuple(filter_size, 2)
16 | self.stride = lasagne.layers.dnn.as_tuple(stride, 2)
17 | self.target_shape = target_shape
18 |
19 | self.W_shape = (incoming.output_shape[1], target_shape[1], filter_size[0], filter_size[1])
20 | self.W = self.add_param(W, self.W_shape, name="W")
21 | if b is not None:
22 | self.b = self.add_param(b, (target_shape[1],), name="b")
23 | else:
24 | self.b = None
25 |
26 | def get_output_for(self, input, **kwargs):
27 | op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=self.target_shape, kshp=self.W_shape, subsample=self.stride, border_mode='half')
28 | activation = op(self.W, input, self.target_shape[2:])
29 |
30 | if self.b is not None:
31 | activation += self.b.dimshuffle('x', 0, 'x', 'x')
32 |
33 | return self.nonlinearity(activation)
34 |
35 | def get_output_shape_for(self, input_shape):
36 | return self.target_shape
--------------------------------------------------------------------------------
/layers/merge.py:
--------------------------------------------------------------------------------
1 | import lasagne
2 | from lasagne import init
3 | from lasagne import nonlinearities
4 |
5 | import theano.tensor as T
6 | import theano
7 | import numpy as np
8 | import theano.tensor.extra_ops as Textra
9 |
10 | __all__ = [
11 | "ConvConcatLayer", #
12 | "MLPConcatLayer", #
13 | ]
14 |
15 |
16 | class ConvConcatLayer(lasagne.layers.MergeLayer):
17 | '''
18 | concatenate a tensor and a vector on feature map axis
19 | '''
20 | def __init__(self, incomings, num_cls, **kwargs):
21 | super(ConvConcatLayer, self).__init__(incomings, **kwargs)
22 | self.num_cls = num_cls
23 |
24 | def get_output_shape_for(self, input_shapes):
25 | res = list(input_shapes[0])
26 | res[1] += self.num_cls
27 | return tuple(res)
28 |
29 | def get_output_for(self, input, **kwargs):
30 | x, y = input
31 | if y.ndim == 1:
32 | y = T.extra_ops.to_one_hot(y, self.num_cls)
33 | if y.ndim == 2:
34 | y = y.dimshuffle(0, 1, 'x', 'x')
35 | assert y.ndim == 4
36 | return T.concatenate([x, y*T.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))], axis=1)
37 |
38 | class MLPConcatLayer(lasagne.layers.MergeLayer):
39 | '''
40 | concatenate a matrix and a vector on feature axis
41 | '''
42 | def __init__(self, incomings, num_cls, **kwargs):
43 | super(MLPConcatLayer, self).__init__(incomings, **kwargs)
44 | self.num_cls = num_cls
45 |
46 | def get_output_shape_for(self, input_shapes):
47 | res = list(input_shapes[0])
48 | res[1] += self.num_cls
49 | return tuple(res)
50 |
51 | def get_output_for(self, input, **kwargs):
52 | x, y = input
53 | if y.ndim == 1:
54 | y = T.extra_ops.to_one_hot(y, self.num_cls)
55 | assert y.ndim == 2
56 | return T.concatenate([x, y], axis=1)
--------------------------------------------------------------------------------
/layers/sample.py:
--------------------------------------------------------------------------------
1 | import lasagne
2 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
3 | import theano.tensor as T
4 | import theano
5 |
6 |
7 | class GaussianSampleLayer(lasagne.layers.MergeLayer):
8 | """
9 | Simple sampling layer drawing a single Monte Carlo sample to approximate
10 | E_q [log( p(x,z) / q(z|x) )]. This is the approach described in [KINGMA]_.
11 | Parameters
12 | ----------
13 | mu, log_var : :class:`Layer` instances
14 | Parameterizing the mean and log(variance) of the distribution to sample
15 | from as described in [KINGMA]_. The code assumes that these have the
16 | same number of dimensions.
17 | seed : int
18 | seed to random stream
19 | Methods
20 | ----------
21 | seed : Helper function to change the random seed after init is called
22 | References
23 | ----------
24 | .. [KINGMA] Kingma, Diederik P., and Max Welling.
25 | "Auto-Encoding Variational Bayes."
26 | arXiv preprint arXiv:1312.6114 (2013).
27 | """
28 | def __init__(self, mean, log_var,
29 | seed=None,
30 | **kwargs):
31 |         super(GaussianSampleLayer, self).__init__([mean, log_var], **kwargs)
32 |
33 | if seed is None:
34 | seed = lasagne.random.get_rng().randint(1, 2147462579)
35 | self._srng = RandomStreams(seed)
36 |
37 | def seed(self, seed=None):
38 | if seed is None:
39 | seed = lasagne.random.get_rng().randint(1, 2147462579)
40 | self._srng.seed(seed)
41 |
42 | def get_output_shape_for(self, input_shapes):
43 | return input_shapes[0]
44 |
45 | def get_output_for(self, input, **kwargs):
46 | mu, log_var = input
47 | eps = self._srng.normal(mu.shape)
48 | z = mu + T.exp(0.5 * log_var) * eps
49 | return z
50 |
51 | class BernoulliSampleLayer(lasagne.layers.Layer):
52 | """
53 | Simple sampling layer drawing samples from bernoulli distributions.
54 | Parameters
55 | ----------
56 | mean : :class:`Layer` instances
57 | Parameterizing the mean value of each bernoulli distribution
58 | seed : int
59 | seed to random stream
60 | Methods
61 | ----------
62 | seed : Helper function to change the random seed after init is called
63 | """
64 |
65 | def __init__(self, mean,
66 | seed=None,
67 | **kwargs):
68 |         super(BernoulliSampleLayer, self).__init__(mean, **kwargs)
69 |
70 | if seed is None:
71 | seed = lasagne.random.get_rng().randint(1, 2147462579)
72 |
73 | self._srng = RandomStreams(seed)
74 |
75 | def seed(self, seed=lasagne.random.get_rng().randint(1, 2147462579)):
76 | self._srng.seed(seed)
77 |
78 | def get_output_shape_for(self, input_shape):
79 | return input_shape
80 |
81 | def get_output_for(self, mu, **kwargs):
82 | return self._srng.binomial(size=mu.shape, p=mu, dtype=mu.dtype)
--------------------------------------------------------------------------------
/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | neural network stuff, intended to be used with Lasagne
3 | """
4 |
5 | import numpy as np
6 | import theano as th
7 | import theano.tensor as T
8 | import lasagne
9 | from lasagne.layers import dnn
10 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
11 |
12 | # T.nnet.relu has some stability issues; this is better
13 | def relu(x):
14 | return T.maximum(x, 0)
15 |
16 | def lrelu(x, a=0.2):
17 | return T.maximum(x, a*x)
18 |
19 | def centered_softplus(x):
20 | return T.nnet.softplus(x) - np.cast[th.config.floatX](np.log(2.))
21 |
22 | def log_sum_exp(x, axis=1):
23 | m = T.max(x, axis=axis)
24 | return m+T.log(T.sum(T.exp(x-m.dimshuffle(0,'x')), axis=axis))
25 |
26 | def adam_updates(params, cost, lr=0.001, mom1=0.9, mom2=0.999):
27 | updates = []
28 | grads = T.grad(cost, params)
29 | t = th.shared(np.cast[th.config.floatX](1.))
30 | for p, g in zip(params, grads):
31 | value = p.get_value(borrow=True)
32 | v = th.shared(np.zeros(value.shape, dtype=value.dtype),
33 | broadcastable=p.broadcastable)
34 | mg = th.shared(np.zeros(value.shape, dtype=value.dtype),
35 | broadcastable=p.broadcastable)
36 |
37 | v_t = mom1*v + (1. - mom1)*g
38 | mg_t = mom2*mg + (1. - mom2)*T.square(g)
39 | v_hat = v_t / (1. - mom1 ** t)
40 | mg_hat = mg_t / (1. - mom2 ** t)
41 | g_t = v_hat / T.sqrt(mg_hat + 1e-8)
42 | p_t = p - lr * g_t
43 | updates.append((v, v_t))
44 | updates.append((mg, mg_t))
45 | updates.append((p, p_t))
46 | updates.append((t, t+1))
47 | return updates
48 |
49 | class WeightNormLayer(lasagne.layers.Layer):
50 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.),
51 | W=lasagne.init.Normal(0.05), train_g=False, init_stdv=1., nonlinearity=relu, **kwargs):
52 | super(WeightNormLayer, self).__init__(incoming, **kwargs)
53 | self.nonlinearity = nonlinearity
54 | self.init_stdv = init_stdv
55 | k = self.input_shape[1]
56 | if b is not None:
57 | self.b = self.add_param(b, (k,), name="b", regularizable=False)
58 | if g is not None:
59 | self.g = self.add_param(g, (k,), name="g", regularizable=False, trainable=train_g)
60 | if len(self.input_shape)==4:
61 | self.axes_to_sum = (0,2,3)
62 | self.dimshuffle_args = ['x',0,'x','x']
63 | else:
64 | self.axes_to_sum = 0
65 | self.dimshuffle_args = ['x',0]
66 |
67 | # scale weights in layer below
68 | incoming.W_param = incoming.W
69 | #incoming.W_param.set_value(W.sample(incoming.W_param.get_value().shape))
70 | if incoming.W_param.ndim==4:
71 | if isinstance(incoming, Deconv2DLayer):
72 | W_axes_to_sum = (0,2,3)
73 | W_dimshuffle_args = ['x',0,'x','x']
74 | else:
75 | W_axes_to_sum = (1,2,3)
76 | W_dimshuffle_args = [0,'x','x','x']
77 | else:
78 | W_axes_to_sum = 0
79 | W_dimshuffle_args = ['x',0]
80 | if g is not None:
81 | incoming.W = incoming.W_param * (self.g/T.sqrt(1e-6 + T.sum(T.square(incoming.W_param),axis=W_axes_to_sum))).dimshuffle(*W_dimshuffle_args)
82 | else:
83 | incoming.W = incoming.W_param / T.sqrt(1e-6 + T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,keepdims=True))
84 |
85 | def get_output_for(self, input, init=False, **kwargs):
86 | if init:
87 | m = T.mean(input, self.axes_to_sum)
88 | input -= m.dimshuffle(*self.dimshuffle_args)
89 | inv_stdv = self.init_stdv/T.sqrt(T.mean(T.square(input), self.axes_to_sum))
90 | input *= inv_stdv.dimshuffle(*self.dimshuffle_args)
91 | self.init_updates = [(self.b, -m*inv_stdv), (self.g, self.g*inv_stdv)]
92 | elif hasattr(self,'b'):
93 | input += self.b.dimshuffle(*self.dimshuffle_args)
94 |
95 | return self.nonlinearity(input)
96 |
97 | def weight_norm(layer, **kwargs):
98 | nonlinearity = getattr(layer, 'nonlinearity', None)
99 | if nonlinearity is not None:
100 | layer.nonlinearity = lasagne.nonlinearities.identity
101 | if hasattr(layer, 'b'):
102 | del layer.params[layer.b]
103 | layer.b = None
104 | return WeightNormLayer(layer, nonlinearity=nonlinearity, **kwargs)
105 |
106 | class Deconv2DLayer(lasagne.layers.Layer):
107 | def __init__(self, incoming, target_shape, filter_size, stride=(2, 2),
108 | W=lasagne.init.Normal(0.05), b=lasagne.init.Constant(0.), nonlinearity=relu, **kwargs):
109 | super(Deconv2DLayer, self).__init__(incoming, **kwargs)
110 | self.target_shape = target_shape
111 | self.nonlinearity = (lasagne.nonlinearities.identity if nonlinearity is None else nonlinearity)
112 | self.filter_size = lasagne.layers.dnn.as_tuple(filter_size, 2)
113 | self.stride = lasagne.layers.dnn.as_tuple(stride, 2)
114 | self.target_shape = target_shape
115 |
116 | self.W_shape = (incoming.output_shape[1], target_shape[1], filter_size[0], filter_size[1])
117 | self.W = self.add_param(W, self.W_shape, name="W")
118 | if b is not None:
119 | self.b = self.add_param(b, (target_shape[1],), name="b")
120 | else:
121 | self.b = None
122 |
123 | def get_output_for(self, input, **kwargs):
124 | op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=self.target_shape, kshp=self.W_shape, subsample=self.stride, border_mode='half')
125 | activation = op(self.W, input, self.target_shape[2:])
126 |
127 | if self.b is not None:
128 | activation += self.b.dimshuffle('x', 0, 'x', 'x')
129 |
130 | return self.nonlinearity(activation)
131 |
132 | def get_output_shape_for(self, input_shape):
133 | return self.target_shape
134 |
135 | # minibatch discrimination layer
136 | class MinibatchLayer(lasagne.layers.Layer):
137 | def __init__(self, incoming, num_kernels, dim_per_kernel=5, theta=lasagne.init.Normal(0.05),
138 | log_weight_scale=lasagne.init.Constant(0.), b=lasagne.init.Constant(-1.), **kwargs):
139 | super(MinibatchLayer, self).__init__(incoming, **kwargs)
140 | self.num_kernels = num_kernels
141 | num_inputs = int(np.prod(self.input_shape[1:]))
142 | self.theta = self.add_param(theta, (num_inputs, num_kernels, dim_per_kernel), name="theta")
143 | self.log_weight_scale = self.add_param(log_weight_scale, (num_kernels, dim_per_kernel), name="log_weight_scale")
144 | self.W = self.theta * (T.exp(self.log_weight_scale)/T.sqrt(T.sum(T.square(self.theta),axis=0))).dimshuffle('x',0,1)
145 | self.b = self.add_param(b, (num_kernels,), name="b")
146 |
147 | def get_output_shape_for(self, input_shape):
148 | return (input_shape[0], np.prod(input_shape[1:])+self.num_kernels)
149 |
150 | def get_output_for(self, input, init=False, **kwargs):
151 | if input.ndim > 2:
152 | # if the input has more than two dimensions, flatten it into a
153 | # batch of feature vectors.
154 | input = input.flatten(2)
155 |
156 | activation = T.tensordot(input, self.W, [[1], [0]])
157 | abs_dif = (T.sum(abs(activation.dimshuffle(0,1,2,'x') - activation.dimshuffle('x',1,2,0)),axis=2)
158 | + 1e6 * T.eye(input.shape[0]).dimshuffle(0,'x',1))
159 |
160 | if init:
161 | mean_min_abs_dif = 0.5 * T.mean(T.min(abs_dif, axis=2),axis=0)
162 | abs_dif /= mean_min_abs_dif.dimshuffle('x',0,'x')
163 | self.init_updates = [(self.log_weight_scale, self.log_weight_scale-T.log(mean_min_abs_dif).dimshuffle(0,'x'))]
164 |
165 | f = T.sum(T.exp(-abs_dif),axis=2)
166 |
167 | if init:
168 | mf = T.mean(f,axis=0)
169 | f -= mf.dimshuffle('x',0)
170 | self.init_updates.append((self.b, -mf))
171 | else:
172 | f += self.b.dimshuffle('x',0)
173 |
174 | return T.concatenate([input, f], axis=1)
175 |
176 | class BatchNormLayer(lasagne.layers.Layer):
177 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), nonlinearity=relu, **kwargs):
178 | super(BatchNormLayer, self).__init__(incoming, **kwargs)
179 | self.nonlinearity = nonlinearity
180 | k = self.input_shape[1]
181 | if b is not None:
182 | self.b = self.add_param(b, (k,), name="b", regularizable=False)
183 | if g is not None:
184 | self.g = self.add_param(g, (k,), name="g", regularizable=False)
185 | self.avg_batch_mean = self.add_param(lasagne.init.Constant(0.), (k,), name="avg_batch_mean", regularizable=False, trainable=False)
186 | self.avg_batch_var = self.add_param(lasagne.init.Constant(1.), (k,), name="avg_batch_var", regularizable=False, trainable=False)
187 | if len(self.input_shape)==4:
188 | self.axes_to_sum = (0,2,3)
189 | self.dimshuffle_args = ['x',0,'x','x']
190 | else:
191 | self.axes_to_sum = 0
192 | self.dimshuffle_args = ['x',0]
193 |
194 | def get_output_for(self, input, deterministic=False, **kwargs):
195 | if deterministic:
196 | norm_features = (input-self.avg_batch_mean.dimshuffle(*self.dimshuffle_args)) / T.sqrt(1e-6 + self.avg_batch_var).dimshuffle(*self.dimshuffle_args)
197 | else:
198 | batch_mean = T.mean(input,axis=self.axes_to_sum).flatten()
199 | centered_input = input-batch_mean.dimshuffle(*self.dimshuffle_args)
200 | batch_var = T.mean(T.square(centered_input),axis=self.axes_to_sum).flatten()
201 | batch_stdv = T.sqrt(1e-6 + batch_var)
202 | norm_features = centered_input / batch_stdv.dimshuffle(*self.dimshuffle_args)
203 |
204 | # BN updates
205 | new_m = 0.9*self.avg_batch_mean + 0.1*batch_mean
206 | new_v = 0.9*self.avg_batch_var + T.cast((0.1*input.shape[0])/(input.shape[0]-1),th.config.floatX)*batch_var
207 | self.bn_updates = [(self.avg_batch_mean, new_m), (self.avg_batch_var, new_v)]
208 |
209 | if hasattr(self, 'g'):
210 | activation = norm_features*self.g.dimshuffle(*self.dimshuffle_args)
211 | else:
212 | activation = norm_features
213 | if hasattr(self, 'b'):
214 | activation += self.b.dimshuffle(*self.dimshuffle_args)
215 |
216 | return self.nonlinearity(activation)
217 |
218 | def batch_norm(layer, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), **kwargs):
219 | """
220 | adapted from https://gist.github.com/f0k/f1a6bd3c8585c400c190
221 | """
222 | nonlinearity = getattr(layer, 'nonlinearity', None)
223 | if nonlinearity is not None:
224 | layer.nonlinearity = lasagne.nonlinearities.identity
225 | else:
226 | nonlinearity = lasagne.nonlinearities.identity
227 | if hasattr(layer, 'b'):
228 | del layer.params[layer.b]
229 | layer.b = None
230 | return BatchNormLayer(layer, b, g, nonlinearity=nonlinearity, **kwargs)
231 |
232 | class GaussianNoiseLayer(lasagne.layers.Layer):
233 | def __init__(self, incoming, sigma=0.1, **kwargs):
234 | super(GaussianNoiseLayer, self).__init__(incoming, **kwargs)
235 | self._srng = RandomStreams(lasagne.random.get_rng().randint(1, 2147462579))
236 | self.sigma = sigma
237 |
238 | def get_output_for(self, input, deterministic=False, use_last_noise=False, **kwargs):
239 | if deterministic or self.sigma == 0:
240 | return input
241 | else:
242 | if not use_last_noise:
243 | self.noise = self._srng.normal(input.shape, avg=0.0, std=self.sigma)
244 | return input + self.noise
245 |
246 |
247 | # /////////// older code used for MNIST ////////////
248 |
249 | # weight normalization
250 | def l2normalize(layer, train_scale=True):
251 | W_param = layer.W
252 | s = W_param.get_value().shape
253 | if len(s)==4:
254 | axes_to_sum = (1,2,3)
255 | dimshuffle_args = [0,'x','x','x']
256 | k = s[0]
257 | else:
258 | axes_to_sum = 0
259 | dimshuffle_args = ['x',0]
260 | k = s[1]
261 | layer.W_scale = layer.add_param(lasagne.init.Constant(1.),
262 | (k,), name="W_scale", trainable=train_scale, regularizable=False)
263 | layer.W = W_param * (layer.W_scale/T.sqrt(1e-6 + T.sum(T.square(W_param),axis=axes_to_sum))).dimshuffle(*dimshuffle_args)
264 | return layer
265 |
266 | # fully connected layer with weight normalization
267 | class DenseLayer(lasagne.layers.Layer):
268 | def __init__(self, incoming, num_units, theta=lasagne.init.Normal(0.1), b=lasagne.init.Constant(0.),
269 | weight_scale=lasagne.init.Constant(1.), train_scale=False, nonlinearity=relu, **kwargs):
270 | super(DenseLayer, self).__init__(incoming, **kwargs)
271 | self.nonlinearity = (lasagne.nonlinearities.identity if nonlinearity is None else nonlinearity)
272 | self.num_units = num_units
273 | num_inputs = int(np.prod(self.input_shape[1:]))
274 | self.theta = self.add_param(theta, (num_inputs, num_units), name="theta")
275 | self.weight_scale = self.add_param(weight_scale, (num_units,), name="weight_scale", trainable=train_scale)
276 | self.W = self.theta * (self.weight_scale/T.sqrt(T.sum(T.square(self.theta),axis=0))).dimshuffle('x',0)
277 | self.b = self.add_param(b, (num_units,), name="b")
278 |
279 | def get_output_shape_for(self, input_shape):
280 | return (input_shape[0], self.num_units)
281 |
282 | def get_output_for(self, input, init=False, deterministic=False, **kwargs):
283 | if input.ndim > 2:
284 | # if the input has more than two dimensions, flatten it into a
285 | # batch of feature vectors.
286 | input = input.flatten(2)
287 |
288 | activation = T.dot(input, self.W)
289 |
290 | if init:
291 | ma = T.mean(activation, axis=0)
292 | activation -= ma.dimshuffle('x',0)
293 | stdv = T.sqrt(T.mean(T.square(activation),axis=0))
294 | activation /= stdv.dimshuffle('x',0)
295 | self.init_updates = [(self.weight_scale, self.weight_scale/stdv), (self.b, -ma/stdv)]
296 | else:
297 | activation += self.b.dimshuffle('x', 0)
298 |
299 | return self.nonlinearity(activation)
300 |
301 | from scipy import linalg
302 | class ZCA(object):
303 | def __init__(self, regularization=1e-5, x=None):
304 | self.regularization = regularization
305 | if x is not None:
306 | self.fit(x)
307 |
308 | def fit(self, x):
309 | s = x.shape
310 | x = x.copy().reshape((s[0],np.prod(s[1:])))
311 | m = np.mean(x, axis=0)
312 | x -= m
313 | sigma = np.dot(x.T,x) / x.shape[0]
314 | U, S, V = linalg.svd(sigma)
315 | tmp = np.dot(U, np.diag(1./np.sqrt(S+self.regularization)))
316 | tmp2 = np.dot(U, np.diag(np.sqrt(S+self.regularization)))
317 | self.ZCA_mat = th.shared(np.dot(tmp, U.T).astype(th.config.floatX))
318 | self.inv_ZCA_mat = th.shared(np.dot(tmp2, U.T).astype(th.config.floatX))
319 | self.mean = th.shared(m.astype(th.config.floatX))
320 |
321 | def apply(self, x):
322 | s = x.shape
323 | if isinstance(x, np.ndarray):
324 | return np.dot(x.reshape((s[0],np.prod(s[1:]))) - self.mean.get_value(), self.ZCA_mat.get_value()).reshape(s)
325 | elif isinstance(x, T.TensorVariable):
326 | return T.dot(x.flatten(2) - self.mean.dimshuffle('x',0), self.ZCA_mat).reshape(s)
327 | else:
328 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables")
329 |
330 | def invert(self, x):
331 | s = x.shape
332 | if isinstance(x, np.ndarray):
333 | return (np.dot(x.reshape((s[0],np.prod(s[1:]))), self.inv_ZCA_mat.get_value()) + self.mean.get_value()).reshape(s)
334 | elif isinstance(x, T.TensorVariable):
335 | return (T.dot(x.flatten(2), self.inv_ZCA_mat) + self.mean.dimshuffle('x',0)).reshape(s)
336 | else:
337 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables")
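338 |
339 | # Usage sketch (hypothetical layer sizes): batch_norm / weight_norm wrap an existing Lasagne
340 | # layer and move its nonlinearity (and bias) into the normalization layer, e.g.
341 | #   h = batch_norm(lasagne.layers.DenseLayer(l_in, num_units=500, nonlinearity=relu), g=None)
342 | #   h = weight_norm(lasagne.layers.Conv2DLayer(h, 96, 3, pad=1, nonlinearity=lrelu), train_g=True, init_stdv=0.1)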
--------------------------------------------------------------------------------
/svhn_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from six.moves import urllib
4 | from scipy.io import loadmat
5 |
6 | def maybe_download(data_dir):
7 | new_data_dir = os.path.join(data_dir, 'svhn')
8 | if not os.path.exists(new_data_dir):
9 | os.makedirs(new_data_dir)
10 | def _progress(count, block_size, total_size):
11 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0))
12 | sys.stdout.flush()
13 | filepath, _ = urllib.request.urlretrieve('http://ufldl.stanford.edu/housenumbers/train_32x32.mat', new_data_dir+'/train_32x32.mat', _progress)
14 | filepath, _ = urllib.request.urlretrieve('http://ufldl.stanford.edu/housenumbers/test_32x32.mat', new_data_dir+'/test_32x32.mat', _progress)
15 |
16 | def load(data_dir, subset='train'):
17 | maybe_download(data_dir)
18 | if subset=='train':
19 | train_data = loadmat(os.path.join(data_dir, 'svhn') + '/train_32x32.mat')
20 | trainx = train_data['X']
21 | trainy = train_data['y'].flatten()
22 | trainy[trainy==10] = 0
23 | return trainx, trainy
24 | elif subset=='test':
25 | test_data = loadmat(os.path.join(data_dir, 'svhn') + '/test_32x32.mat')
26 | testx = test_data['X']
27 | testy = test_data['y'].flatten()
28 | testy[testy==10] = 0
29 | return testx, testy
30 | else:
31 | raise NotImplementedError('subset should be either train or test')
32 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .paramgraphics import *
2 | from .others import *
3 | from .create_ssl_data import *
--------------------------------------------------------------------------------
/utils/checkpoints.py:
--------------------------------------------------------------------------------
1 | import pickle, imp, time
2 | import gzip, logging, operator, os
3 | import os.path as osp
4 | import numpy as np
5 | from path import Path
6 | import coloredlogs
7 |
8 |
9 | def convert2dict(params):
10 | names = [par.name for par in params]
11 | assert len(names) == len(set(names))
12 |
13 | param_dict = {par.name: par.get_value() for par in params}
14 | return param_dict
15 |
16 |
17 | def save_weights(fname, params, history=None):
18 | param_dict = convert2dict(params)
19 |
20 | logging.info('saving {} parameters to {}'.format(len(params), fname))
21 | fname = Path(fname)
22 |
23 | filename, ext = osp.splitext(fname)
24 | history_file = osp.join(osp.dirname(fname), 'history.npy')
25 | np.save(history_file, history)
26 | logging.info("Save history to {}".format(history_file))
27 | if ext == '.npy':
28 | np.save(filename + '.npy', param_dict)
29 | else:
30 | f = gzip.open(fname, 'wb')
31 | pickle.dump(param_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
32 | f.close()
33 |
34 |
35 | def load_dict(fname):
36 | logging.info("Loading weights from {}".format(fname))
37 | filename, ext = os.path.splitext(fname)
38 | if ext == '.npy':
39 | params_load = np.load(fname).item()
40 | else:
41 | f = gzip.open(fname, 'r')
42 | params_load = pickle.load(f)
43 | f.close()
44 | if type(params_load) is dict:
45 | param_dict = params_load
46 | else:
47 | param_dict = convert2dict(params_load)
48 | return param_dict
49 |
50 | def load_weights_trainable(fname, l_out):
51 | import lasagne
52 | params = lasagne.layers.get_all_params(l_out, trainable=True)
53 | names = [par.name for par in params]
54 | assert len(names) == len(set(names))
55 |
56 | if type(fname) is list:
57 | param_dict = {}
58 | for name in fname:
59 | t_load = load_dict(name)
60 | param_dict.update(t_load)
61 | else:
62 | param_dict = load_dict(fname)
63 |
64 | for param in params:
65 | if param.name in param_dict:
66 | stored_shape = np.asarray(param_dict[param.name].shape)
67 | param_shape = np.asarray(param.get_value().shape)
68 | if not np.all(stored_shape == param_shape):
69 | warn_msg = 'shape mismatch:'
70 | warn_msg += '{} stored:{} new:{}'.format(
71 | param.name, stored_shape, param_shape)
72 | warn_msg += ', skipping'
73 | logging.warn(warn_msg)
74 | else:
75 | param.set_value(param_dict[param.name])
76 | else:
77 | logging.warn('unable to load parameter {} from {}: No such variable.'
78 | .format(param.name, fname))
79 |
80 |
81 |
82 | def load_weights(fname, l_out):
83 | import lasagne
84 | params = lasagne.layers.get_all_params(l_out)
85 | names = [par.name for par in params]
86 | assert len(names) == len(set(names))
87 |
88 | if type(fname) is list:
89 | param_dict = {}
90 | for name in fname:
91 | t_load = load_dict(name)
92 | param_dict.update(t_load)
93 | else:
94 | param_dict = load_dict(fname)
95 | assign_weights(params, param_dict)
96 |
97 | def assign_weights(params, param_dict):
98 | for param in params:
99 | if param.name in param_dict:
100 | stored_shape = np.asarray(param_dict[param.name].shape)
101 | param_shape = np.asarray(param.get_value().shape)
102 | if not np.all(stored_shape == param_shape):
103 | warn_msg = 'shape mismatch:'
104 | warn_msg += '{} stored:{} new:{}'.format(
105 | param.name, stored_shape, param_shape)
106 | warn_msg += ', skipping'
107 | logging.warn(warn_msg)
108 | else:
109 | param.set_value(param_dict[param.name])
110 | else:
111 | logging.warn('Unable to load parameter {}: No such variable.'
112 | .format(param.name))
113 |
114 |
115 | def get_list_name(obj):
116 | if type(obj) is list:
117 | for i in range(len(obj)):
118 | if callable(obj[i]):
119 | obj[i] = obj[i].__name__
120 | elif callable(obj):
121 | obj = obj.__name__
122 | return obj
123 |
124 |
125 | # write commandline parameters to header of logfile
126 | def build_log_file(cfg):
127 | FORMAT="%(asctime)s;%(levelname)s|%(message)s"
128 | DATEF="%H-%M-%S"
129 |     logging.basicConfig(format=FORMAT, level=logging.DEBUG)
130 | logger = logging.getLogger()
131 | logger.setLevel(logging.DEBUG)
132 |
133 | fh = logging.FileHandler(filename=os.path.join(cfg['outfolder'], 'logfile'+time.strftime("%m-%d")+'.log'))
134 | fh.setLevel(logging.DEBUG)
135 | formatter = logging.Formatter("%(asctime)s;%(levelname)s|%(message)s", "%H:%M:%S")
136 | fh.setFormatter(formatter)
137 | logger.addHandler(fh)
138 |
139 | LEVEL_STYLES = dict(
140 | debug=dict(color='magenta'),
141 | info=dict(color='green'),
142 | verbose=dict(),
143 | warning=dict(color='blue'),
144 | error=dict(color='yellow'),
145 | critical=dict(color='red',bold=True))
146 | coloredlogs.install(level=logging.DEBUG, fmt=FORMAT, datefmt=DATEF, level_styles=LEVEL_STYLES)
147 |
148 |
149 | args_dict = cfg
150 | sorted_args = sorted(args_dict.items(), key=operator.itemgetter(0))
151 | logging.info('######################################################')
152 | logging.info('# --Configurable Parameters In this Model--')
153 | for name, val in sorted_args:
154 | logging.info("# " + name + ":\t" + str(get_list_name(val)))
155 | logging.info('######################################################')
156 |
157 |
158 | def get_cfg(args):
159 | if args.cfg is not None:
160 | cfg = imp.load_source('config', args.cfg)
161 | else:
162 |         raise Exception("The path of the config file (args.cfg) must be provided")
163 |
164 | getmodel = cfg.get_model
165 | cfg = cfg.cfg
166 | args = vars(args).items()
167 | for name, val in args:
168 | cfg[name] = val
169 |
170 | cfg['outfolder'] = os.path.join(cfg['outfolder'], cfg['name'])
171 | res_out = cfg['outfolder']
172 | if 'key_point' in cfg:
173 | res_out += '.'+ cfg['key_point']
174 | if cfg['key_point'] in cfg:
175 | res_out += '-' + str(cfg[cfg['key_point']])
176 | if 'notime' not in cfg or cfg['notime'] in [False, 'False', 'false', None, 'none', 'None']:
177 | res_out += '.' + time.strftime("%b-%d--%H-%M")
178 |
179 | res_out = os.path.realpath(res_out)
180 | if os.path.exists(res_out):
181 | tcount = 1
182 | while os.path.exists(res_out+'+'+str(tcount)):
183 | tcount += 1
184 | res_out += '+' + str(tcount)
185 |
186 | # print res_out
187 | os.makedirs(res_out)
188 | cfg['outfolder'] = res_out
189 | return cfg, getmodel
190 |
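191 | # Usage sketch (the file name is hypothetical): parameters are stored in a dict keyed by
192 | # parameter name, so names must be unique, e.g.
193 | #   save_weights('model.npy', lasagne.layers.get_all_params(l_out), history=None)
194 | #   load_weights('model.npy', l_out)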
--------------------------------------------------------------------------------
/utils/create_ssl_data.py:
--------------------------------------------------------------------------------
1 | '''
2 | Create semi-supervised datasets for different models
3 | '''
4 | import numpy as np
5 |
6 | def create_ssl_data(x, y, n_classes, n_labelled, seed):
7 | # 'x': data matrix, nxk
8 | # 'y': label vector, n
9 | # 'n_classes': number of classes
10 | # 'n_labelled': number of labelled data
11 | # 'seed': random seed
12 |
13 | # check input
14 | if n_labelled%n_classes != 0:
15 | print n_labelled
16 | print n_classes
17 |         raise ValueError("n_labelled (desired number of labelled samples) is not divisible by n_classes (number of classes)")
18 | n_labels_per_class = n_labelled/n_classes
19 |
20 | rng = np.random.RandomState(seed)
21 | index = rng.permutation(x.shape[0])
22 | x = x[index]
23 | y = y[index]
24 |
25 | # select first several data per class
26 | data_labelled = [0]*n_classes
27 | index_labelled = []
28 | index_unlabelled = []
29 | for i in xrange(x.shape[0]):
30 | if data_labelled[y[i]] < n_labels_per_class:
31 | data_labelled[y[i]] += 1
32 | index_labelled.append(i)
33 | else:
34 | index_unlabelled.append(i)
35 |
36 | x_labelled = x[index_labelled]
37 | y_labelled = y[index_labelled]
38 | x_unlabelled = x[index_unlabelled]
39 | y_unlabelled = y[index_unlabelled]
40 | return x_labelled, y_labelled, x_unlabelled, y_unlabelled
41 |
42 |
43 | def create_ssl_data_subset(x, y, n_classes, n_labelled, n_labelled_per_time, seed):
44 | assert n_labelled%n_labelled_per_time==0
45 | times = n_labelled/n_labelled_per_time
46 | x_labelled, y_labelled, x_unlabelled, y_unlabelled = create_ssl_data(x, y, n_classes, n_labelled_per_time, seed)
47 | while (times > 1):
48 | x_labelled_new, y_labelled_new, x_unlabelled, y_unlabelled = create_ssl_data(x_unlabelled, y_unlabelled, n_classes, n_labelled_per_time, seed)
49 | x_labelled = np.vstack((x_labelled, x_labelled_new))
50 | y_labelled = np.hstack((y_labelled, y_labelled_new))
51 | times -= 1
52 | return x_labelled, y_labelled, x_unlabelled, y_unlabelled
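53 |
54 | # Usage sketch (hypothetical arrays and label budget): keep n_labelled examples, balanced
55 | # across classes, and return the rest as unlabelled data, e.g.
56 | #   x_l, y_l, x_u, y_u = create_ssl_data(trainx, trainy, n_classes=10, n_labelled=4000, seed=1234)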
--------------------------------------------------------------------------------
/utils/others.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import theano.tensor as T
3 | import theano, lasagne
4 |
5 |
6 | def get_pad(pad):
7 | if pad not in ['same', 'valid', 'full']:
8 | pad = tuple(map(int, pad.split('-')))
9 | return pad
10 |
11 | def get_pad_list(pad_list):
12 | re_list = []
13 | for p in pad_list:
14 | re_list.append(get_pad(p))
15 | return re_list
16 |
17 | # nonlinearities
18 | def get_nonlin(nonlin):
19 | if nonlin == 'rectify':
20 | return lasagne.nonlinearities.rectify
21 | elif nonlin == 'leaky_rectify':
22 | return lasagne.nonlinearities.LeakyRectify(0.1)
23 | elif nonlin == 'tanh':
24 | return lasagne.nonlinearities.tanh
25 | elif nonlin == 'sigmoid':
26 | return lasagne.nonlinearities.sigmoid
27 | elif nonlin == 'maxout':
28 | return 'maxout'
29 | elif nonlin == 'none':
30 | return lasagne.nonlinearities.identity
31 | else:
32 | raise ValueError('invalid non-linearity \'' + nonlin + '\'')
33 |
34 | def get_nonlin_list(nonlin_list):
35 | re_list = []
36 | for n in nonlin_list:
37 | re_list.append(get_nonlin(n))
38 | return re_list
39 |
40 | def bernoullisample(x):
41 | return np.random.binomial(1,x,size=x.shape).astype(theano.config.floatX)
42 |
43 | def array2file_2D(array,logfile):
44 | assert len(array.shape) == 2, array.shape
45 | with open(logfile,'a') as f:
46 | for i in xrange(array.shape[0]):
47 | for j in xrange(array.shape[1]):
48 | f.write(str(array[i][j])+' ')
49 | f.write('\n')
50 |
51 | def printarray_2D(array, precise=2):
52 | assert len(array.shape) == 2, array.shape
53 | format = '%.'+str(precise)+'f'
54 | for i in xrange(array.shape[0]):
55 | for j in xrange(array.shape[1]):
56 | print format %array[i][j],
57 | print
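# --- Editor's usage sketch (not part of the original file) --------------------
# Hedged illustration of the helpers above: get_pad passes Lasagne's
# 'same'/'valid'/'full' keywords through unchanged and parses anything else of
# the form '<h>-<w>' into an integer tuple, while get_nonlin maps config strings
# to Lasagne nonlinearities (leaky_rectify uses a fixed 0.1 slope).
assert get_pad('same') == 'same'
assert get_pad('2-1') == (2, 1)
assert get_nonlin('tanh') is lasagne.nonlinearities.tanh
assert get_nonlin_list(['rectify', 'none'])[1] is lasagne.nonlinearities.identity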
--------------------------------------------------------------------------------
/utils/paramgraphics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from scipy.misc import imsave
4 |
5 | def scale_max_min(images, max_p, min_p):
6 | # scale the images according to the max and min
7 | # images f x n, column major
8 | ret = np.zeros(images.shape)
9 | for i in xrange(images.shape[1]):
10 | # clips at first
11 | tmp = np.clip(images[:,i], min_p[i], max_p[i])
12 | # scale
13 | ret[:,i] = (tmp - min_p[i]) / (max_p[i] - min_p[i])
14 |
15 | return ret
16 |
17 | def scale_to_unit_interval(ndar, eps=1e-8):
18 | """ Scales all values in the ndarray ndar to be between 0 and 1 """
19 | ndar = ndar.copy()
20 | ndar -= ndar.min()
21 | ndar *= 1.0 / (ndar.max() + eps)
22 | return ndar
23 |
24 | def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
25 | scale=True,
26 | output_pixel_vals=True,
27 | colorImg=False):
28 | """
29 | Transform an array with one flattened image per row, into an array in
30 |     which images are reshaped and laid out like tiles on a floor.
31 |
32 | This function is useful for visualizing datasets whose rows are images,
33 | and also columns of matrices for transforming those rows
34 | (such as the first layer of a neural net).
35 |
36 | :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can
37 | be 2-D ndarrays or None;
38 | :param X: a 2-D array in which every row is a flattened image.
39 |
40 | :type img_shape: tuple; (height, width)
41 | :param img_shape: the original shape of each image
42 |
43 | :type tile_shape: tuple; (rows, cols)
44 | :param tile_shape: the number of images to tile (rows, cols)
45 |
46 |     :param output_pixel_vals: if output should be pixel values (i.e. uint8
47 | values) or floats
48 |
49 |     :param scale: whether the values need to be scaled to the unit interval
50 |      [0, 1] before being plotted
51 |
52 |
53 | :returns: array suitable for viewing as an image.
54 | """
55 | X = X * 1.0 # converts ints to floats
56 |
57 | if colorImg:
58 | channelSize = X.shape[1]/3
59 | X = (X[:,0:channelSize], X[:,channelSize:2*channelSize], X[:,2*channelSize:3*channelSize], None)
60 |
61 | assert len(img_shape) == 2
62 | assert len(tile_shape) == 2
63 | assert len(tile_spacing) == 2
64 |
65 | # The expression below can be re-written in a more C style as
66 | # follows :
67 | #
68 | # out_shape = [0,0]
69 | # out_shape[0] = (img_shape[0] + tile_spacing[0]) * tile_shape[0] -
70 | # tile_spacing[0]
71 | # out_shape[1] = (img_shape[1] + tile_spacing[1]) * tile_shape[1] -
72 | # tile_spacing[1]
73 | out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp
74 | in zip(img_shape, tile_shape, tile_spacing)]
75 |
76 | if isinstance(X, tuple):
77 | assert len(X) == 4
78 | # Create an output np ndarray to store the image
79 | if output_pixel_vals:
80 | out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype='uint8')
81 | else:
82 | out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype)
83 |
84 | #colors default to 0, alpha defaults to 1 (opaque)
85 | if output_pixel_vals:
86 | channel_defaults = [0, 0, 0, 255]
87 | else:
88 | channel_defaults = [0., 0., 0., 1.]
89 |
90 |
91 | for i in xrange(4):
92 | if X[i] is None:
93 | # if channel is None, fill it with zeros of the correct
94 | # dtype
95 | out_array[:, :, i] = np.zeros(out_shape,
96 | dtype='uint8' if output_pixel_vals else out_array.dtype
97 | ) + channel_defaults[i]
98 | else:
99 |                 # use a recursive call to compute the channel and store it
100 | # in the output
101 | xi = X[i]
102 | if scale:
103 | xi = (X[i] - X[i].min()) / (X[i].max() - X[i].min())
104 | out_array[:, :, i] = tile_raster_images(xi, img_shape, tile_shape, tile_spacing, False, output_pixel_vals)
105 |
106 |
107 | return out_array
108 |
109 | else:
110 | # if we are dealing with only one channel
111 | H, W = img_shape
112 | Hs, Ws = tile_spacing
113 |
114 | # generate a matrix to store the output
115 | out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype)
116 |
117 |
118 | for tile_row in xrange(tile_shape[0]):
119 | for tile_col in xrange(tile_shape[1]):
120 | if tile_row * tile_shape[1] + tile_col < X.shape[0]:
121 | if scale:
122 | # if we should scale values to be between 0 and 1
123 | # do this by calling the `scale_to_unit_interval`
124 | # function
125 | tmp = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)
126 | this_img = scale_to_unit_interval(tmp)
127 | else:
128 | this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)
129 | # add the slice to the corresponding position in the
130 | # output array
131 | out_array[
132 | tile_row * (H+Hs): tile_row * (H + Hs) + H,
133 | tile_col * (W+Ws): tile_col * (W + Ws) + W
134 | ] \
135 | = this_img * (255 if output_pixel_vals else 1)
136 | return out_array
137 |
138 | # Matrix to image
139 | def mat_to_img(w, dim_input, scale=False, colorImg=False, tile_spacing=(1,1), tile_shape=0, save_path=None):
140 | if tile_shape == 0:
141 | rowscols = int(w.shape[1]**0.5)
142 | tile_shape = (rowscols,rowscols)
143 | w = w[:, 0:rowscols*rowscols]
144 | imgs = tile_raster_images(X=w.T, img_shape=dim_input, tile_shape=tile_shape, tile_spacing=tile_spacing, scale=scale, colorImg=colorImg)
145 | if save_path is not None:
146 | imsave(save_path, imgs)
147 | return imgs
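# --- Editor's usage sketch (not part of the original file) --------------------
# mat_to_img expects one flattened image per *column* (f x n). With 16 random
# 28x28 images and the default tile_spacing=(1, 1) it returns a single uint8
# canvas of (28+1)*4-1 = 115 pixels per side; passing save_path also writes the
# image to disk (the module-level imsave import assumes the older scipy pinned
# in the README).
import numpy as np

w_toy = np.random.rand(28 * 28, 16).astype('float32')  # 16 flattened images as columns
canvas = mat_to_img(w_toy, dim_input=(28, 28))          # tiled into a 4x4 grid
assert canvas.shape == (115, 115) and canvas.dtype == np.uint8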
--------------------------------------------------------------------------------
/x2y_yz2x_xy2p_ssl_cifar10.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code implements Triple-GAN for semi-supervised learning on CIFAR10
3 | '''
4 | import os, sys, time
5 | import numpy as np
6 |
7 | import scipy
8 | from collections import OrderedDict
9 | import pickle
10 |
11 | from lasagne.layers import InputLayer, ReshapeLayer, FlattenLayer, Upscale2DLayer, MaxPool2DLayer, DropoutLayer, ConcatLayer, DenseLayer, NINLayer
12 | from lasagne.layers import GaussianNoiseLayer, Conv2DLayer, Pool2DLayer, GlobalPoolLayer, NonlinearityLayer, FeaturePoolLayer, DimshuffleLayer, ElemwiseSumLayer
13 | from lasagne.utils import floatX
14 | from zca_bn import ZCA
15 | from zca_bn import mean_only_bn as WN
16 |
17 | import gzip, os, cPickle, time, math, argparse, shutil, sys
18 |
19 | import numpy as np
20 | import theano, lasagne
21 | import theano.tensor as T
22 | import lasagne.layers as ll
23 | import lasagne.nonlinearities as ln
24 | from lasagne.layers import dnn
25 | import nn
26 | from lasagne.init import Normal
27 | from theano.sandbox.rng_mrg import MRG_RandomStreams
28 | import cifar10_data
29 |
30 | from layers.merge import ConvConcatLayer, MLPConcatLayer
31 | from layers.deconv import Deconv2DLayer
32 |
33 | from components.shortcuts import convlayer, mlplayer
34 | from components.objectives import categorical_crossentropy_ssl_separated, maximum_mean_discripancy, categorical_crossentropy, feature_matching
35 | from utils.create_ssl_data import create_ssl_data, create_ssl_data_subset
36 | from utils.others import get_nonlin_list, get_pad_list, bernoullisample, printarray_2D, array2file_2D
37 | import utils.paramgraphics as paramgraphics
38 |
39 | def build_network():
40 | conv_defs = {
41 | 'W': lasagne.init.HeNormal('relu'),
42 | 'b': lasagne.init.Constant(0.0),
43 | 'filter_size': (3, 3),
44 | 'stride': (1, 1),
45 | 'nonlinearity': lasagne.nonlinearities.LeakyRectify(0.1)
46 | }
47 |
48 | nin_defs = {
49 | 'W': lasagne.init.HeNormal('relu'),
50 | 'b': lasagne.init.Constant(0.0),
51 | 'nonlinearity': lasagne.nonlinearities.LeakyRectify(0.1)
52 | }
53 |
54 | dense_defs = {
55 | 'W': lasagne.init.HeNormal(1.0),
56 | 'b': lasagne.init.Constant(0.0),
57 | 'nonlinearity': lasagne.nonlinearities.softmax
58 | }
59 |
60 | wn_defs = {
61 | 'momentum': .999
62 | }
63 |
64 | net = InputLayer ( name='input', shape=(None, 3, 32, 32))
65 | net = GaussianNoiseLayer(net, name='noise', sigma=.15)
66 | net = WN(Conv2DLayer (net, name='conv1a', num_filters=128, pad='same', **conv_defs), **wn_defs)
67 | net = WN(Conv2DLayer (net, name='conv1b', num_filters=128, pad='same', **conv_defs), **wn_defs)
68 | net = WN(Conv2DLayer (net, name='conv1c', num_filters=128, pad='same', **conv_defs), **wn_defs)
69 | net = MaxPool2DLayer (net, name='pool1', pool_size=(2, 2))
70 | net = DropoutLayer (net, name='drop1', p=.5)
71 | net = WN(Conv2DLayer (net, name='conv2a', num_filters=256, pad='same', **conv_defs), **wn_defs)
72 | net = WN(Conv2DLayer (net, name='conv2b', num_filters=256, pad='same', **conv_defs), **wn_defs)
73 | net = WN(Conv2DLayer (net, name='conv2c', num_filters=256, pad='same', **conv_defs), **wn_defs)
74 | net = MaxPool2DLayer (net, name='pool2', pool_size=(2, 2))
75 | net = DropoutLayer (net, name='drop2', p=.5)
76 | net = WN(Conv2DLayer (net, name='conv3a', num_filters=512, pad=0, **conv_defs), **wn_defs)
77 | net = WN(NINLayer (net, name='conv3b', num_units=256, **nin_defs), **wn_defs)
78 | net = WN(NINLayer (net, name='conv3c', num_units=128, **nin_defs), **wn_defs)
79 | net = GlobalPoolLayer (net, name='pool3')
80 | net = WN(DenseLayer (net, name='dense', num_units=10, **dense_defs), **wn_defs)
81 |
82 | return net
83 |
84 | def rampup(epoch):
85 | if epoch < 80:
86 | p = max(0.0, float(epoch)) / float(80)
87 | p = 1.0 - p
88 | return math.exp(-p*p*5.0)
89 | else:
90 | return 1.0
91 |
92 | def rampdown(epoch):
93 | if epoch >= (300 - 50):
94 | ep = (epoch - (300 - 50)) * 0.5
95 | return math.exp(-(ep * ep) / 50)
96 | else:
97 | return 1.0
98 |
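# --- Editor's note (not part of the original file) ----------------------------
# The two schedules above appear to follow the Temporal Ensembling code credited
# in the README: rampup rises from exp(-5) at epoch 0 to 1.0 at epoch 80 and is
# used below to scale the unsupervised consistency weight, while rampdown decays
# from 1.0 towards 0 over the last 50 of 300 epochs and is used to move Adam's
# beta1 for C from 0.9 back towards 0.5 at the end of training. Quick check:
import math
assert abs(rampup(0) - math.exp(-5.0)) < 1e-12
assert rampup(80) == 1.0 and rampup(500) == 1.0
assert rampdown(249) == 1.0 and 0.0 < rampdown(299) < 1.0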
99 | def robust_adam(loss, params, learning_rate, beta1=0.9, beta2=0.999, epsilon=1.0e-8):
100 | # Convert NaNs to zeros.
101 | def clear_nan(x):
102 | return T.switch(T.isnan(x), np.float32(0.0), x)
103 |
104 | new = OrderedDict()
105 | pg = zip(params, lasagne.updates.get_or_compute_grads(loss, params))
106 | t = theano.shared(lasagne.utils.floatX(0.))
107 |
108 | new[t] = t + 1.0
109 | coef = learning_rate * T.sqrt(1.0 - beta2**new[t]) / (1.0 - beta1**new[t])
110 | for p, g in pg:
111 | value = p.get_value(borrow=True)
112 | m = theano.shared(np.zeros(value.shape, dtype=value.dtype), broadcastable=p.broadcastable)
113 | v = theano.shared(np.zeros(value.shape, dtype=value.dtype), broadcastable=p.broadcastable)
114 | new[m] = clear_nan(beta1 * m + (1.0 - beta1) * g)
115 | new[v] = clear_nan(beta2 * v + (1.0 - beta2) * g**2)
116 | new[p] = clear_nan(p - coef * new[m] / (T.sqrt(new[v]) + epsilon))
117 |
118 | return new
119 |
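# --- Editor's sketch (not part of the original file) --------------------------
# A plain-numpy restatement of a single robust_adam step, mirroring the Theano
# updates above: standard bias-corrected Adam, except each new moment and the
# new parameter value are passed through a NaN -> 0 clamp. Illustrative only;
# the training code below uses the Theano version.
import numpy as np

def robust_adam_step_np(p, g, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    clear_nan = lambda a: np.where(np.isnan(a), 0.0, a)
    t = t + 1.0
    coef = lr * np.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)  # bias correction
    m = clear_nan(beta1 * m + (1.0 - beta1) * g)                # 1st moment
    v = clear_nan(beta2 * v + (1.0 - beta2) * g ** 2)           # 2nd moment
    p = clear_nan(p - coef * m / (np.sqrt(v) + eps))            # parameter step
    return p, m, v, t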
120 | '''
121 | parameters
122 | '''
123 | # global
124 | parser = argparse.ArgumentParser()
125 | parser.add_argument("-key", type=str, default=argparse.SUPPRESS)
126 | parser.add_argument("-ssl_seed", type=int, default=1)
127 | parser.add_argument("-nlabeled", type=int, default=4000)
128 | parser.add_argument("-cla_g", type=float, default=0.1)
129 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS)
130 | args = parser.parse_args()
131 | args = vars(args).items()
132 | cfg = {}
133 | for name, val in args:
134 | cfg[name] = val
135 |
136 | filename_script=os.path.basename(os.path.realpath(__file__))
137 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0])
138 | outfolder+='.'
139 | for item in cfg:
140 |     if item != 'oldmodel':
141 | outfolder += item+str(cfg[item])+'.'
142 | else:
143 | outfolder += 'oldmodel.'
144 | outfolder+=str(int(time.time()))
145 | if not os.path.exists(outfolder):
146 | os.makedirs(outfolder)
147 | sample_path = os.path.join(outfolder, 'sample')
148 | os.makedirs(sample_path)
149 | logfile=os.path.join(outfolder, 'logfile.log')
150 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script))
151 | # fixed random seeds
152 | ssl_data_seed=cfg['ssl_seed']
153 | num_labelled=cfg['nlabeled']
154 | alpha_cla_g=cfg['cla_g']
155 | print ssl_data_seed, num_labelled
156 |
157 | seed=1234
158 | rng=np.random.RandomState(seed)
159 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15))
160 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15)))
161 |
162 | # flags
163 | valid_flag=False
164 | # C
165 | alpha_cla_adv = 0.01
166 | alpha_cla=1.
167 | scaled_unsup_weight_max = 100.0
168 | # G
169 | n_z=100
170 | epoch_cla_g=200
171 | # D
172 | noise_D_data=.3
173 | noise_D=.5
174 | # optimization
175 | b1_g=.5 # mom1 in Adam
176 | b1_d=.5
177 | batch_size_g=200
178 | batch_size_l_c=100
179 | batch_size_u_c=100
180 | batch_size_u_d=160
181 | batch_size_l_d=200-batch_size_u_d
182 | lr=3e-4
183 | cla_lr=3e-3
184 | num_epochs=1000
185 | anneal_lr_epoch=300
186 | anneal_lr_every_epoch=1
187 | anneal_lr_factor_cla=.99
188 | anneal_lr_factor=.995
189 | # data dependent
190 | gen_final_non=ln.tanh
191 | num_classes=10
192 | dim_input=(32,32)
193 | in_channels=3
194 | colorImg=True
195 | generation_scale=True
196 | z_generated=num_classes
197 | # evaluation
198 | vis_epoch=10
199 | eval_epoch=1
200 | batch_size_eval=200
201 |
202 |
203 | '''
204 | data
205 | '''
206 | def rescale(mat):
207 | return np.cast[theano.config.floatX](mat)
208 |
209 | train_x, train_y = cifar10_data.load('/home/chongxuan/mfs/data/cifar10/','train')
210 | eval_x, eval_y = cifar10_data.load('/home/chongxuan/mfs/data/cifar10/','test')
211 |
212 | train_y = np.int32(train_y)
213 | eval_y = np.int32(eval_y)
214 | train_x = rescale(train_x)
215 | eval_x = rescale(eval_x)
216 | x_unlabelled = train_x.copy()
217 |
218 | rng_data = np.random.RandomState(ssl_data_seed)
219 | inds = rng_data.permutation(train_x.shape[0])
220 | train_x = train_x[inds]
221 | train_y = train_y[inds]
222 | x_labelled = []
223 | y_labelled = []
224 | for j in range(num_classes):
225 | x_labelled.append(train_x[train_y==j][:num_labelled/num_classes])
226 | y_labelled.append(train_y[train_y==j][:num_labelled/num_classes])
227 | x_labelled = np.concatenate(x_labelled, axis=0)
228 | y_labelled = np.concatenate(y_labelled, axis=0)
229 | del train_x
230 |
231 | if True:
232 | print 'Size of training data', x_labelled.shape[0], x_unlabelled.shape[0]
233 | y_order = np.argsort(y_labelled)
234 | _x_mean = x_labelled[y_order]
235 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, tile_shape=(num_classes, num_labelled/num_classes), colorImg=colorImg, scale=generation_scale, save_path=os.path.join(outfolder, 'x_l_'+str(ssl_data_seed)+'_triple-gan.png'))
236 |
237 | n_batches_train_u_c = x_unlabelled.shape[0] / batch_size_u_c
238 | n_batches_train_l_c = x_labelled.shape[0] / batch_size_l_c
239 | n_batches_train_u_d = x_unlabelled.shape[0] / batch_size_u_d
240 | n_batches_train_l_d = x_labelled.shape[0] / batch_size_l_d
241 | n_batches_train_g = x_unlabelled.shape[0] / batch_size_g
242 | n_batches_eval = eval_x.shape[0] / batch_size_eval
243 |
244 |
245 | '''
246 | models
247 | '''
248 | # symbols
249 | sym_z_image = T.tile(theano_rng.uniform((z_generated, n_z)), (num_classes, 1))
250 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z))
251 | sym_x_u = T.tensor4()
252 | sym_x_u_d = T.tensor4()
253 | sym_x_u_g = T.tensor4()
254 | sym_x_l = T.tensor4()
255 | sym_y = T.ivector()
256 | sym_y_g = T.ivector()
257 | sym_x_eval = T.tensor4()
258 | sym_lr = T.scalar()
259 | sym_alpha_cla_g = T.scalar()
260 | sym_alpha_unlabel_entropy = T.scalar()
261 | sym_alpha_unlabel_average = T.scalar()
262 |
263 | # te
264 | sym_lr_cla = T.scalar('separate_lr')
265 | sym_x_u_rep = T.tensor4('two_pass')
266 | sym_unsup_weight = T.scalar('unsup_weight')
267 | sym_b_c = T.scalar('adam_beta1')
268 |
269 |
270 | shared_labeled = theano.shared(x_labelled, borrow=True)
271 | shared_labely = theano.shared(y_labelled, borrow=True)
272 | shared_unlabel = theano.shared(x_unlabelled, borrow=True)
273 | slice_x_u_g = T.ivector()
274 | slice_x_u_d = T.ivector()
275 | slice_x_u_c = T.ivector()
276 |
277 | classifier = build_network()
278 |
279 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z)
280 | gen_in_z = ll.InputLayer(shape=(None, n_z))
281 | gen_in_y = ll.InputLayer(shape=(None,))
282 | gen_layers = [gen_in_z]
283 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-00'))
284 | gen_layers.append(nn.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=4*4*512, W=Normal(0.05), nonlinearity=nn.relu, name='gen-01'), g=None, name='gen-02'))
285 | gen_layers.append(ll.ReshapeLayer(gen_layers[-1], (-1,512,4,4), name='gen-03'))
286 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-10'))
287 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,256,8,8), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-11'), g=None, name='gen-12')) # 4 -> 8
288 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-20'))
289 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,128,16,16), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-21'), g=None, name='gen-22')) # 8 -> 16
290 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-30'))
291 | gen_layers.append(nn.weight_norm(nn.Deconv2DLayer(gen_layers[-1], (None,3,32,32), (5,5), W=Normal(0.05), nonlinearity=gen_final_non, name='gen-31'), train_g=True, init_stdv=0.1, name='gen-32')) # 16 -> 32
292 |
293 | # discriminator xy2p: tests whether an input pair (x, y) comes from the true distribution p(x, y) rather than from p_c or p_g
294 | dis_in_x = ll.InputLayer(shape=(None, in_channels) + dim_input)
295 | dis_in_y = ll.InputLayer(shape=(None,))
296 | dis_layers = [dis_in_x]
297 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-00'))
298 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-01'))
299 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-02'), name='dis-03'))
300 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-20'))
301 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-21'), name='dis-22'))
302 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-23'))
303 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-30'))
304 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-31'), name='dis-32'))
305 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-40'))
306 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-41'), name='dis-42'))
307 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-43'))
308 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-50'))
309 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-51'), name='dis-52'))
310 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-60'))
311 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-61'), name='dis-62'))
312 | dis_layers.append(ll.GlobalPoolLayer(dis_layers[-1], name='dis-63'))
313 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-70'))
314 | dis_layers.append(nn.weight_norm(ll.DenseLayer(dis_layers[-1], num_units=1, W=Normal(0.05), nonlinearity=ln.sigmoid, name='dis-71'), train_g=True, init_stdv=0.1, name='dis-72'))
315 |
316 |
317 | '''
318 | objectives
319 | '''
320 | # zca
321 | whitener = ZCA(x=x_unlabelled)
322 | sym_x_l_zca = whitener.apply(sym_x_l)
323 | sym_x_eval_zca = whitener.apply(sym_x_eval)
324 | sym_x_u_zca = whitener.apply(sym_x_u)
325 | sym_x_u_rep_zca = whitener.apply(sym_x_u_rep)
326 | sym_x_u_d_zca = whitener.apply(sym_x_u_d)
327 |
328 | # init
329 | lasagne.layers.get_output(classifier, sym_x_u_zca, init=True)
330 | init_updates = [u for l in lasagne.layers.get_all_layers(classifier) for u in getattr(l, 'init_updates', [])]
331 | init_fn = theano.function([sym_x_u], [], updates=init_updates)
332 |
333 | # outputs
334 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False)
335 | gen_out_x_zca = whitener.apply(gen_out_x)
336 | cla_out_y_l = ll.get_output(classifier, sym_x_l_zca, deterministic=False)
337 | cla_out_y_eval = ll.get_output(classifier, sym_x_eval_zca, deterministic=True)
338 | cla_out_y = ll.get_output(classifier, sym_x_u_zca, deterministic=False)
339 | cla_out_y_rep = ll.get_output(classifier, sym_x_u_rep_zca, deterministic=False)
340 | bn_updates = [u for l in lasagne.layers.get_all_layers(classifier) for u in getattr(l, 'bn_updates', [])]
341 |
342 | cla_out_y_d = ll.get_output(classifier, sym_x_u_d_zca, deterministic=False)
343 | cla_out_y_d_hard = cla_out_y_d.argmax(axis=1)
344 | cla_out_y_g = ll.get_output(classifier, gen_out_x_zca, deterministic=False)
345 |
346 | dis_out_p = ll.get_output(dis_layers[-1], {dis_in_x:T.concatenate([sym_x_l,sym_x_u_d], axis=0),dis_in_y:T.concatenate([sym_y,cla_out_y_d_hard], axis=0)}, deterministic=False)
347 | dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x:gen_out_x,dis_in_y:sym_y_g}, deterministic=False)
348 | # argmax
349 | cla_out_y_hard = cla_out_y.argmax(axis=1)
350 | dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x:sym_x_u,dis_in_y:cla_out_y_hard}, deterministic=False)
351 |
352 | image = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_image}, deterministic=False) # for generation
353 |
354 | accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y)) # for evaluation
355 | accurracy_eval = accurracy_eval.mean()
356 |
357 | # costs
358 | bce = lasagne.objectives.binary_crossentropy
359 |
360 | dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean() # D classifies pairs from p as real
361 | dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean() # D classifies pairs from p_g as fake
362 | gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean() # G fools D
363 |
364 |
365 | dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape)) # D classifies pairs from p_c as fake
366 |
367 | # argmax
368 | p_cla_max = cla_out_y.max(axis=1)
369 | cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape)) # C fools D
370 | cla_cost_p_c = (cla_cost_p_c*p_cla_max).mean()
371 | dis_cost_p_c = dis_cost_p_c.mean()
372 |
373 |
374 | cla_cost_l = T.mean(lasagne.objectives.categorical_crossentropy(cla_out_y_l, sym_y), dtype=theano.config.floatX, acc_dtype=theano.config.floatX)
375 |
376 | cla_cost_u = sym_unsup_weight * T.mean(lasagne.objectives.squared_error(cla_out_y, cla_out_y_rep), dtype=theano.config.floatX, acc_dtype=theano.config.floatX)
377 |
378 | cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g)
379 |
380 | dis_cost = dis_cost_p + .5*dis_cost_p_g + .5*dis_cost_p_c
381 | gen_cost = .5*gen_cost_p_g
382 | cla_cost = alpha_cla_adv * .5 * cla_cost_p_c + alpha_cla * (cla_cost_l + cla_cost_u) + sym_alpha_cla_g * cla_cost_cla_g
383 | # fast
384 | cla_cost_fast = alpha_cla_adv * .5 * cla_cost_p_c + alpha_cla*(cla_cost_l + cla_cost_u)
385 |
386 | dis_cost_list=[dis_cost, dis_cost_p, .5*dis_cost_p_g, .5*dis_cost_p_c]
387 | gen_cost_list=[gen_cost,]
388 | cla_cost_list=[cla_cost, alpha_cla_adv * .5 * cla_cost_p_c, alpha_cla*cla_cost_l, alpha_cla*cla_cost_u, sym_alpha_cla_g*cla_cost_cla_g]
389 | # fast
390 | cla_cost_list_fast=[cla_cost_fast, alpha_cla_adv * .5 * cla_cost_p_c, alpha_cla*cla_cost_l, alpha_cla*cla_cost_u, ]
391 |
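# --- Editor's note (not part of the original file) ----------------------------
# Numeric illustration of the weighting used above: D sees real (x, y) pairs with
# weight 1 and pairs produced by G and C with weight 0.5 each, while G (and, via
# alpha_cla_adv, C) pushes its own pairs towards the "real" label with the same
# non-saturating binary cross-entropy. Toy numbers with lasagne's bce convention:
import numpy as np

def bce_toy(p, t):
    return -(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))

d_p, d_pg, d_pc = 0.9, 0.2, 0.3  # toy D outputs on pairs from p, p_g and p_c
dis_cost_toy = bce_toy(d_p, 1.) + .5 * bce_toy(d_pg, 0.) + .5 * bce_toy(d_pc, 0.)
gen_cost_toy = .5 * bce_toy(d_pg, 1.)  # G's loss on its own pairs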
392 | # updates of D
393 | dis_params = ll.get_all_params(dis_layers, trainable=True)
394 | dis_grads = T.grad(dis_cost, dis_params)
395 | dis_updates = lasagne.updates.adam(dis_grads, dis_params, beta1=b1_d, learning_rate=sym_lr)
396 |
397 | # updates of G
398 | gen_params = ll.get_all_params(gen_layers, trainable=True)
399 | gen_grads = T.grad(gen_cost, gen_params)
400 | gen_updates = lasagne.updates.adam(gen_grads, gen_params, beta1=b1_g, learning_rate=sym_lr)
401 |
402 | # updates of C
403 | cla_params = ll.get_all_params(classifier, trainable=True)
404 | cla_updates_ = robust_adam(cla_cost, cla_params, learning_rate=sym_lr_cla, beta1=sym_b_c, beta2=.999, epsilon=1e-8)
405 |
406 | # fast updates of C
407 | cla_params = ll.get_all_params(classifier, trainable=True)
408 | cla_updates_fast_ = robust_adam(cla_cost_fast, cla_params, learning_rate=sym_lr_cla, beta1=sym_b_c, beta2=.999, epsilon=1e-8)
409 |
410 | ######## avg
411 | avg_params = lasagne.layers.get_all_params(classifier)
412 | cla_param_avg=[]
413 | for param in avg_params:
414 | value = param.get_value(borrow=True)
415 | cla_param_avg.append(theano.shared(np.zeros(value.shape, dtype=value.dtype),
416 | broadcastable=param.broadcastable,
417 | name=param.name))
418 | cla_avg_updates = [(a,a + 0.01*(p-a)) for p,a in zip(avg_params,cla_param_avg)]
419 | cla_avg_givens = [(p,a) for p,a in zip(avg_params, cla_param_avg)]
420 | cla_updates = cla_updates_.items() + bn_updates + cla_avg_updates
421 | cla_updates_fast = cla_updates_fast_.items()+ bn_updates + cla_avg_updates
422 |
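# --- Editor's note (not part of the original file) ----------------------------
# cla_avg_updates above keeps an exponential moving average of the classifier
# parameters (a <- a + 0.01*(p - a), i.e. decay 0.99); evaluate() below swaps in
# these averaged copies via `givens`, so evaluation never touches the raw live
# weights. Tiny numeric check of the recursion:
avg, live = 0.0, 1.0
for _ in range(500):
    avg += 0.01 * (live - avg)
assert abs(avg - live) < 0.01  # the average converges towards the live parameter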
423 | # functions
424 | train_batch_dis = theano.function(inputs=[sym_x_l, sym_y, sym_y_g,
425 | slice_x_u_c, slice_x_u_d, sym_lr],
426 | outputs=dis_cost_list, updates=dis_updates,
427 | givens={sym_x_u: shared_unlabel[slice_x_u_c],
428 | sym_x_u_d: shared_unlabel[slice_x_u_d]})
429 | train_batch_gen = theano.function(inputs=[sym_y_g, sym_lr],
430 | outputs=gen_cost_list, updates=gen_updates)
431 | train_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, slice_x_u_c, sym_alpha_cla_g, sym_lr_cla, sym_b_c, sym_unsup_weight],
432 | outputs=cla_cost_list , updates=cla_updates,
433 | givens={sym_x_u: shared_unlabel[slice_x_u_c],
434 | sym_x_u_rep: shared_unlabel[slice_x_u_c]})
435 | # fast
436 | train_batch_cla_fast = theano.function(inputs=[sym_x_l, sym_y, slice_x_u_c, sym_lr_cla, sym_b_c, sym_unsup_weight],
437 | outputs=cla_cost_list_fast, updates=cla_updates_fast,
438 | givens={sym_x_u: shared_unlabel[slice_x_u_c],
439 | sym_x_u_rep: shared_unlabel[slice_x_u_c]})
440 |
441 | generate = theano.function(inputs=[sym_y_g], outputs=image)
442 | # avg
443 | evaluate = theano.function(inputs=[sym_x_eval, sym_y], outputs=[accurracy_eval], givens=cla_avg_givens)
444 |
445 | '''
446 | Load pretrained model
447 | '''
448 | if 'oldmodel' in cfg:
449 | from utils.checkpoints import load_weights
450 | load_weights(cfg['oldmodel'], dis_layers+[classifier,]+gen_layers)
451 | for (p, a) in zip(ll.get_all_params(classifier), avg_params):
452 | a.set_value(p.get_value())
453 |
454 | '''
455 | train and evaluate
456 | '''
457 |
458 | init_fn(x_unlabelled[:batch_size_u_c])
459 |
460 | print 'Start training'
461 | for epoch in range(1, 1+num_epochs):
462 | start = time.time()
463 |
464 | # randomly permute data and labels
465 | p_l = rng.permutation(x_labelled.shape[0])
466 | x_labelled = x_labelled[p_l]
467 | y_labelled = y_labelled[p_l]
468 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
469 | p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32')
470 | p_u_g = rng.permutation(x_unlabelled.shape[0]).astype('int32')
471 |
472 | dl = [0.] * len(dis_cost_list)
473 | gl = [0.] * len(gen_cost_list)
474 | cl = [0.] * len(cla_cost_list)
475 |
476 | # fast
477 |
478 | dis_t=0.
479 | gen_t=0.
480 | cla_t=0.
481 |
482 | # te
483 | rampup_value = rampup(epoch-1)
484 | rampdown_value = rampdown(epoch-1)
485 | lr_c = cla_lr
486 | b1_c = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
487 | unsup_weight = rampup_value * scaled_unsup_weight_max if epoch > 1 else 0.
488 |
489 | #print "@cla", lr_c, b1_c, unsup_weight
490 |
491 | for i in range(n_batches_train_u_c):
492 | from_u_c = i*batch_size_u_c
493 | to_u_c = (i+1)*batch_size_u_c
494 | i_c = i % n_batches_train_l_c
495 | from_l_c = i_c*batch_size_l_c
496 | to_l_c = (i_c+1)*batch_size_l_c
497 | i_d = i % n_batches_train_l_d
498 | from_l_d = i_d*batch_size_l_d
499 | to_l_d = (i_d+1)*batch_size_l_d
500 | i_d_ = i % n_batches_train_u_d
501 | from_u_d = i_d_*batch_size_u_d
502 | to_u_d = (i_d_+1)*batch_size_u_d
503 |
504 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
505 |
506 |
507 | tmp = time.time()
508 | dl_b = train_batch_dis(x_labelled[from_l_d:to_l_d], y_labelled[from_l_d:to_l_d], sample_y, p_u[from_u_c:to_u_c], p_u_d[from_u_d:to_u_d], lr)
509 | for j in xrange(len(dl)):
510 | dl[j] += dl_b[j]
511 |
512 | tmp1 = time.time()
513 |
514 | #gl_b = train_batch_gen(sample_y, p_u_g[from_u_g:to_u_g], lr)
515 | gl_b = train_batch_gen(sample_y, lr)
516 | for j in xrange(len(gl)):
517 | gl[j] += gl_b[j]
518 |
519 | tmp2 = time.time()
520 |
521 | # fast
522 | if epoch < epoch_cla_g:
523 | cl_b = train_batch_cla_fast(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], p_u[from_u_c:to_u_c], lr_c, b1_c, unsup_weight)
524 | cl_b += [0,]
525 | else:
526 | cl_b = train_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], sample_y, p_u[from_u_c:to_u_c], alpha_cla_g, lr_c, b1_c, unsup_weight)
527 | for j in xrange(len(cl)):
528 | cl[j] += cl_b[j]
529 |
530 | tmp3 = time.time()
531 | dis_t+=(tmp1-tmp)
532 | gen_t+=(tmp2-tmp1)
533 | cla_t+=(tmp3-tmp2)
534 |
535 | print 'dis:', dis_t, 'gen:', gen_t, 'cla:', cla_t, 'total', dis_t+gen_t+cla_t
536 |
537 | for i in xrange(len(dl)):
538 | dl[i] /= n_batches_train_u_c
539 | for i in xrange(len(gl)):
540 | gl[i] /= n_batches_train_u_c
541 | for i in xrange(len(cl)):
542 | cl[i] /= n_batches_train_u_c
543 |
544 | if (epoch >= anneal_lr_epoch) and (epoch % anneal_lr_every_epoch == 0):
545 | lr = lr*anneal_lr_factor
546 | cla_lr *= anneal_lr_factor_cla
547 |
548 | t = time.time() - start
549 |
550 | line = "*Epoch=%d Time=%.2f LR=%.5f\n" %(epoch, t, lr) + "DisLosses: " + str(dl)+"\nGenLosses: "+str(gl)+"\nClaLosses: "+str(cl)
551 |
552 | print line
553 | with open(logfile,'a') as f:
554 | f.write(line + "\n")
555 |
556 | # random generation for visualization
557 | if epoch % vis_epoch == 0:
558 | import utils.paramgraphics as paramgraphics
559 | tail = '-'+str(epoch)+'.png'
560 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes))
561 | x_gen = generate(ran_y)
562 | x_gen = x_gen.reshape((z_generated*num_classes,-1))
563 | image = paramgraphics.mat_to_img(x_gen.T, dim_input, colorImg=colorImg, scale=generation_scale, save_path=os.path.join(sample_path, 'sample'+tail))
564 |
565 | if epoch % eval_epoch == 0:
566 | accurracy=[]
567 | for i in range(n_batches_eval):
568 | accurracy_batch = evaluate(eval_x[i*batch_size_eval:(i+1)*batch_size_eval], eval_y[i*batch_size_eval:(i+1)*batch_size_eval])
569 | accurracy += accurracy_batch
570 |
571 | accurracy=np.mean(accurracy)
572 | print ('ErrorEval=%.5f\n' % (1-accurracy,))
573 | with open(logfile,'a') as f:
574 | f.write(('ErrorEval=%.5f\n\n' % (1-accurracy,)))
575 |
576 | if epoch % 200 == 0 or (epoch == epoch_cla_g - 1):
577 | from utils.checkpoints import save_weights
578 | params = ll.get_all_params(dis_layers+[classifier,]+gen_layers)
579 | save_weights(os.path.join(outfolder, 'model_epoch' + str(epoch) + '.npy'), params, None)
580 | save_weights(os.path.join(outfolder, 'average'+ str(epoch) +'.npy'), cla_param_avg, None)
581 |
--------------------------------------------------------------------------------
/x2y_yz2x_xy2p_ssl_mnist.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code implements Triple-GAN for semi-supervised learning on MNIST
3 | '''
4 |
5 | import os, time, argparse, shutil
6 |
7 | import numpy as np
8 | import theano, lasagne
9 | import theano.tensor as T
10 | import lasagne.layers as ll
11 | import lasagne.nonlinearities as ln
12 | import nn
13 | from lasagne.init import Normal
14 | from theano.sandbox.rng_mrg import MRG_RandomStreams
15 | from parmesan.datasets import load_mnist_realval
16 |
17 | from layers.merge import MLPConcatLayer
18 |
19 | from components.shortcuts import convlayer
20 | from components.objectives import categorical_crossentropy_ssl_separated, categorical_crossentropy
21 | import utils.paramgraphics as paramgraphics
22 |
23 |
24 | '''
25 | parameters
26 | '''
27 | # global
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("-key", type=str, default=argparse.SUPPRESS)
30 | parser.add_argument("-ssl_seed", type=int, default=1)
31 | parser.add_argument("-nlabeled", type=int, default=100)
32 | parser.add_argument("-objective_flag", type=str, default='argmax')
33 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS)
34 | args = parser.parse_args()
35 | args = vars(args).items()
36 | cfg = {}
37 | for name, val in args:
38 | cfg[name] = val
39 | if cfg['ssl_seed'] == -1:
40 | cfg['ssl_seed'] = int(time.time())
41 |
42 | filename_script=os.path.basename(os.path.realpath(__file__))
43 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0])
44 | outfolder+='.'
45 | for item in cfg:
46 |     if item != 'oldmodel':
47 | outfolder += item+str(cfg[item])+'.'
48 | else:
49 | outfolder += 'oldmodel.'
50 | outfolder+=str(int(time.time()))
51 | if not os.path.exists(outfolder):
52 | os.makedirs(outfolder)
53 | sample_path = os.path.join(outfolder, 'sample')
54 | os.makedirs(sample_path)
55 | logfile=os.path.join(outfolder, 'logfile.log')
56 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script))
57 | # fixed random seeds
58 | ssl_data_seed=cfg['ssl_seed']
59 | num_labelled=cfg['nlabeled']
60 | print ssl_data_seed, num_labelled
61 |
62 | seed=1234
63 | rng=np.random.RandomState(seed)
64 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15))
65 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15)))
66 |
67 | # dataset
68 | data_dir='/home/chongxuan/mfs/data/mnist_real/mnist.pkl.gz'
69 | # flags
70 | valid_flag=False
71 | objective_flag = cfg['objective_flag'] # 'integrate' marginalizes over y; 'argmax' uses the hard prediction of C
72 |
73 | # pre-train C
74 | pre_num_epoch = 0 if num_labelled > 100 else 30
75 | pre_alpha_unlabeled_entropy=.3
76 | pre_alpha_average=.3
77 | pre_lr=3e-4
78 | pre_batch_size_lc=min(100, num_labelled)
79 | pre_batch_size_uc=500
80 | # C
81 | alpha_decay=1e-4
82 | alpha_labeled=1.
83 | alpha_unlabeled_entropy=.3
84 | alpha_average=1e-3
85 | alpha_cla=1.
86 | # G
87 | n_z=100
88 | alpha_cla_g=.1
89 | epoch_cla_g=300
90 | # D
91 | noise_D_data=.3
92 | noise_D=.5
93 | # optimization
94 | b1_g=.5 # mom1 in Adam
95 | b1_d=.5
96 | b1_c=.5
97 |
98 | # adjust batch size for different number of labeled data
99 | batch_size_g=200
100 | batch_size_l_c=min(100, num_labelled)
101 | batch_size_u_c=max(100, 10000/num_labelled)
102 | batch_size_u_d=400
103 | batch_size_l_d=max(num_labelled/100, 1)
104 | lr=1e-3
105 | num_epochs=1000
106 | anneal_lr_epoch=300
107 | anneal_lr_every_epoch=1
108 | anneal_lr_factor=.995
109 | # data dependent
110 | gen_final_non=ln.sigmoid
111 | num_classes=10
112 | dim_input=(28,28)
113 | in_channels=1
114 | colorImg=False
115 | generation_scale=False
116 | z_generated=num_classes
117 | # evaluation
118 | vis_epoch=10
119 | eval_epoch=1
120 |
121 |
122 | '''
123 | data
124 | '''
125 | train_x, train_y, valid_x, valid_y, eval_x, eval_y = load_mnist_realval(data_dir)
126 | if valid_flag:
127 | eval_x = valid_x
128 | eval_y = valid_y
129 | else:
130 | train_x = np.concatenate([train_x, valid_x])
131 | train_y = np.hstack((train_y, valid_y))
132 | train_y = np.int32(train_y)
133 | eval_y = np.int32(eval_y)
134 | train_x = train_x.astype('float32')
135 | eval_x = eval_x.astype('float32')
136 | x_unlabelled = train_x.copy()
137 |
138 | rng_data = np.random.RandomState(ssl_data_seed)
139 | inds = rng_data.permutation(train_x.shape[0])
140 | train_x = train_x[inds]
141 | train_y = train_y[inds]
142 | x_labelled = []
143 | y_labelled = []
144 | for j in range(num_classes):
145 | x_labelled.append(train_x[train_y==j][:num_labelled/num_classes])
146 | y_labelled.append(train_y[train_y==j][:num_labelled/num_classes])
147 | x_labelled = np.concatenate(x_labelled, axis=0)
148 | y_labelled = np.concatenate(y_labelled, axis=0)
149 | del train_x
150 |
151 | if True:
152 | print 'Size of training data', x_labelled.shape[0], x_unlabelled.shape[0]
153 | y_order = np.argsort(y_labelled)
154 | _x_mean = x_labelled[y_order]
155 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, tile_shape=(num_classes, num_labelled/num_classes),colorImg=colorImg, scale=generation_scale, save_path=os.path.join(outfolder, 'x_l_'+str(ssl_data_seed)+'_triple-gan.png'))
156 |
157 | pretrain_batches_train_uc = x_unlabelled.shape[0] / pre_batch_size_uc
158 | pretrain_batches_train_lc = x_labelled.shape[0] / pre_batch_size_lc
159 | n_batches_train_u_c = x_unlabelled.shape[0] / batch_size_u_c
160 | n_batches_train_l_c = x_labelled.shape[0] / batch_size_l_c
161 | n_batches_train_u_d = x_unlabelled.shape[0] / batch_size_u_d
162 | n_batches_train_l_d = x_labelled.shape[0] / batch_size_l_d
163 | n_batches_train_g = x_unlabelled.shape[0] / batch_size_g
164 |
165 |
166 | '''
167 | models
168 | '''
169 | # symbols
170 | sym_z_image = T.tile(theano_rng.uniform((z_generated, n_z)), (num_classes, 1))
171 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z))
172 | sym_x_u = T.matrix()
173 | sym_x_u_d = T.matrix()
174 | sym_x_u_g = T.matrix()
175 | sym_x_l = T.matrix()
176 | sym_y = T.ivector()
177 | sym_y_g = T.ivector()
178 | sym_x_eval = T.matrix()
179 | sym_lr = T.scalar()
180 | sym_alpha_cla_g = T.scalar()
181 | sym_alpha_unlabel_entropy = T.scalar()
182 | sym_alpha_unlabel_average = T.scalar()
183 |
184 | shared_unlabel = theano.shared(x_unlabelled, borrow=True)
185 | slice_x_u_g = T.ivector()
186 | slice_x_u_d = T.ivector()
187 | slice_x_u_c = T.ivector()
188 |
189 | # classifier x2y: p_c(x, y) = p(x) p_c(y | x)
190 | cla_in_x = ll.InputLayer(shape=(None, 28**2))
191 | cla_layers = [cla_in_x]
192 | cla_layers.append(ll.ReshapeLayer(cla_layers[-1], (-1,1,28,28)))
193 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0.5, ps=2, n_kerns=32, d_kerns=(5,5), pad='valid', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-1'))
194 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=64, d_kerns=(3,3), pad='same', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-2'))
195 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0.5, ps=2, n_kerns=64, d_kerns=(3,3), pad='valid', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-3'))
196 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=128, d_kerns=(3,3), pad='same', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-4'))
197 | cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=128, d_kerns=(3,3), pad='same', stride=1, W=Normal(0.05), nonlinearity=ln.rectify, name='cla-5'))
198 | cla_layers.append(ll.GlobalPoolLayer(cla_layers[-1]))
199 | cla_layers.append(ll.DenseLayer(cla_layers[-1], num_units=num_classes, W=lasagne.init.Normal(1e-2, 0), nonlinearity=ln.softmax, name='cla-6'))
200 | classifier = cla_layers[-1]
201 |
202 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z)
203 | gen_in_z = ll.InputLayer(shape=(None, n_z))
204 | gen_in_y = ll.InputLayer(shape=(None,))
205 | gen_layers = [gen_in_z]
206 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-1'))
207 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-2'), name='gen-3'))
208 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-4'))
209 | gen_layers.append(ll.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-5'), name='gen-6'))
210 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-7'))
211 | gen_layers.append(nn.l2normalize(ll.DenseLayer(gen_layers[-1], num_units=28**2, nonlinearity=gen_final_non, name='gen-8')))
212 |
213 | # discriminator xy2p: tests whether an input pair (x, y) comes from the true distribution p(x, y) rather than from p_c or p_g
214 | dis_in_x = ll.InputLayer(shape=(None, 28**2))
215 | dis_in_y = ll.InputLayer(shape=(None,))
216 | dis_layers = [dis_in_x]
217 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D_data, name='dis-1'))
218 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-2'))
219 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=1000, name='dis-3'))
220 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-4'))
221 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-5'))
222 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-6'))
223 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=500, name='dis-7'))
224 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-8'))
225 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-9'))
226 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-10'))
227 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-11'))
228 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-12'))
229 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-13'))
230 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-14'))
231 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-15'))
232 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-16'))
233 | dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-17'))
234 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-18'))
235 | dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=1, nonlinearity=ln.sigmoid, name='dis-19'))
236 |
237 |
238 | '''
239 | objectives
240 | '''
241 | # outputs
242 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False)
243 |
244 | cla_out_y_l = ll.get_output(cla_layers[-1], sym_x_l, deterministic=False)
245 | cla_out_y_eval = ll.get_output(cla_layers[-1], sym_x_eval, deterministic=True)
246 | cla_out_y = ll.get_output(cla_layers[-1], sym_x_u, deterministic=False)
247 | cla_out_y_d = ll.get_output(cla_layers[-1], {cla_in_x:sym_x_u_d}, deterministic=False)
248 | cla_out_y_d_hard = cla_out_y_d.argmax(axis=1)
249 | cla_out_y_g = ll.get_output(cla_layers[-1], {cla_in_x:gen_out_x}, deterministic=False)
250 |
251 | dis_out_p = ll.get_output(dis_layers[-1], {dis_in_x:T.concatenate([sym_x_l,sym_x_u_d], axis=0),dis_in_y:T.concatenate([sym_y,cla_out_y_d_hard], axis=0)}, deterministic=False)
252 | dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x:gen_out_x,dis_in_y:sym_y_g}, deterministic=False)
253 |
254 | if objective_flag == 'integrate':
255 | # integrate
256 | dis_out_p_c = ll.get_output(dis_layers[-1],
257 | {dis_in_x:T.repeat(sym_x_u, num_classes, axis=0),
258 | dis_in_y:np.tile(np.arange(num_classes), batch_size_u_c)},
259 | deterministic=False)
260 | elif objective_flag == 'argmax':
261 | # argmax approximation
262 | cla_out_y_hard = cla_out_y.argmax(axis=1)
263 | dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x:sym_x_u,dis_in_y:cla_out_y_hard}, deterministic=False)
264 | else:
265 |     raise Exception('Unknown objective flag: ' + str(objective_flag))
266 |
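# --- Editor's note (not part of the original file) ----------------------------
# How the 'integrate' branch above pairs data with labels: each unlabelled x is
# repeated num_classes times while the labels 0..num_classes-1 are tiled, so D
# scores every (x, y) combination and the per-class losses are later reweighted
# by C's softmax output. Toy check of the repeat/tile alignment in numpy:
import numpy as np

xs = np.arange(3)                              # three toy "images"
pairs_x = np.repeat(xs, 10)                    # each image repeated for 10 classes
pairs_y = np.tile(np.arange(10), 3)            # class labels 0..9 for every image
assert pairs_x[13] == 1 and pairs_y[13] == 3   # the 14th pair is (image 1, class 3)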
267 | image = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_image}, deterministic=False) # for generation
268 |
269 | accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y)) # for evaluation
270 | accurracy_eval = accurracy_eval.mean()
271 |
272 | # costs
273 | bce = lasagne.objectives.binary_crossentropy
274 |
275 | dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean() # D classifies pairs from p as real
276 | dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean() # D classifies pairs from p_g as fake
277 | gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean() # G fools D
278 |
279 | weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted({cla_layers[-1]:1}, lasagne.regularization.l2) # weight decay
280 |
281 | dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape)) # D classifies pairs from p_c as fake
282 | cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape)) # C fools D
283 |
284 | if objective_flag == 'integrate':
285 | # integrate
286 | weight_loss_c = T.reshape(cla_cost_p_c, (-1, num_classes)) * cla_out_y
287 | cla_cost_p_c = T.sum(weight_loss_c, axis=1).mean()
288 | weight_loss_d = T.reshape(dis_cost_p_c, (-1, num_classes)) * cla_out_y
289 | dis_cost_p_c = T.sum(weight_loss_d, axis=1).mean()
290 | elif objective_flag == 'argmax':
291 | # argmax approximation
292 | p = cla_out_y.max(axis=1)
293 | cla_cost_p_c = (cla_cost_p_c*p).mean()
294 | dis_cost_p_c = dis_cost_p_c.mean()
295 |
296 | cla_cost_cla = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=sym_alpha_unlabel_entropy, alpha_average=sym_alpha_unlabel_average, alpha_decay=alpha_decay) # classification loss
297 |
298 | pretrain_cla_loss = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=pre_alpha_unlabeled_entropy, alpha_average=pre_alpha_average, alpha_decay=alpha_decay) # classification loss
299 | pretrain_cost = pretrain_cla_loss
300 |
301 | cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g)
302 |
303 | dis_cost = dis_cost_p + .5*dis_cost_p_g + .5*dis_cost_p_c
304 | gen_cost = .5*gen_cost_p_g
305 | # flag
306 | cla_cost = .5*cla_cost_p_c + alpha_cla*(cla_cost_cla + sym_alpha_cla_g*cla_cost_cla_g)
307 | # fast
308 | cla_cost_fast = .5*cla_cost_p_c + alpha_cla*cla_cost_cla
309 |
310 | dis_cost_list=[dis_cost, dis_cost_p, .5*dis_cost_p_g, .5*dis_cost_p_c]
311 | gen_cost_list=[gen_cost]
312 | # flag
313 | cla_cost_list=[cla_cost, .5*cla_cost_p_c, alpha_cla*cla_cost_cla, alpha_cla*sym_alpha_cla_g*cla_cost_cla_g]
314 | # fast
315 | cla_cost_list_fast=[cla_cost_fast, .5*cla_cost_p_c, alpha_cla*cla_cost_cla]
316 |
317 | # updates of D
318 | dis_params = ll.get_all_params(dis_layers, trainable=True)
319 | dis_grads = T.grad(dis_cost, dis_params)
320 | dis_updates = lasagne.updates.adam(dis_grads, dis_params, beta1=b1_d, learning_rate=sym_lr)
321 |
322 | # updates of G
323 | gen_params = ll.get_all_params(gen_layers, trainable=True)
324 | gen_grads = T.grad(gen_cost, gen_params)
325 | gen_updates = lasagne.updates.adam(gen_grads, gen_params, beta1=b1_g, learning_rate=sym_lr)
326 |
327 | # updates of C
328 | cla_params = ll.get_all_params(cla_layers, trainable=True)
329 | cla_grads = T.grad(cla_cost, cla_params)
330 | cla_updates_ = lasagne.updates.adam(cla_grads, cla_params, beta1=b1_c, learning_rate=sym_lr)
331 |
332 | # fast updates of C
333 | cla_params = ll.get_all_params(cla_layers, trainable=True)
334 | cla_grads_fast = T.grad(cla_cost_fast, cla_params)
335 | cla_updates_fast_ = lasagne.updates.adam(cla_grads_fast, cla_params, beta1=b1_c, learning_rate=sym_lr)
336 |
337 | pre_cla_grad = T.grad(pretrain_cost, cla_params)
338 | pretrain_updates_ = lasagne.updates.adam(pre_cla_grad, cla_params, beta1=0.9, beta2=0.999,
339 | epsilon=1e-8, learning_rate=pre_lr)
340 |
341 | ######## avg
342 | avg_params = lasagne.layers.get_all_params(cla_layers)
343 | cla_param_avg=[]
344 | for param in avg_params:
345 | value = param.get_value(borrow=True)
346 | cla_param_avg.append(theano.shared(np.zeros(value.shape, dtype=value.dtype),
347 | broadcastable=param.broadcastable,
348 | name=param.name))
349 | cla_avg_updates = [(a,a + 0.01*(p-a)) for p,a in zip(avg_params,cla_param_avg)]
350 | cla_avg_givens = [(p,a) for p,a in zip(avg_params, cla_param_avg)]
351 | cla_updates = cla_updates_.items() + cla_avg_updates
352 | cla_updates_fast = cla_updates_fast_.items() + cla_avg_updates
353 | pretrain_updates = pretrain_updates_.items() + cla_avg_updates
354 |
355 | # functions
356 | train_batch_dis = theano.function(inputs=[sym_x_l, sym_y, sym_y_g,
357 | slice_x_u_c, slice_x_u_d, sym_lr],
358 | outputs=dis_cost_list, updates=dis_updates,
359 | givens={sym_x_u: shared_unlabel[slice_x_u_c],
360 | sym_x_u_d: shared_unlabel[slice_x_u_d]})
361 | train_batch_gen = theano.function(inputs=[sym_y_g, sym_lr],
362 | outputs=gen_cost_list, updates=gen_updates)
363 | train_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, slice_x_u_c, sym_alpha_cla_g, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average],
364 | outputs=cla_cost_list, updates=cla_updates,
365 | givens={sym_x_u: shared_unlabel[slice_x_u_c]})
366 | # fast
367 | train_batch_cla_fast = theano.function(inputs=[sym_x_l, sym_y, slice_x_u_c, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average],
368 | outputs=cla_cost_list_fast, updates=cla_updates_fast,
369 | givens={sym_x_u: shared_unlabel[slice_x_u_c]})
370 |
371 | sym_index = T.iscalar()
372 | bslice = slice(sym_index*pre_batch_size_uc, (sym_index+1)*pre_batch_size_uc)
373 | pretrain_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_index],
374 | outputs=[pretrain_cost], updates=pretrain_updates,
375 | givens={sym_x_u: shared_unlabel[bslice]})
376 | generate = theano.function(inputs=[sym_y_g], outputs=image)
377 |
378 | # avg
379 | evaluate = theano.function(inputs=[sym_x_eval, sym_y], outputs=[accurracy_eval], givens=cla_avg_givens)
380 |
381 |
382 | '''
383 | Load pretrained model
384 | '''
385 | if 'oldmodel' in cfg:
386 | from utils.checkpoints import load_weights
387 | load_weights(cfg['oldmodel'], cla_layers)
388 | for (p, a) in zip(ll.get_all_params(cla_layers), avg_params):
389 | a.set_value(p.get_value())
390 |
391 |
392 | '''
393 | Pretrain C
394 | '''
395 | for epoch in range(1, 1+pre_num_epoch):
396 | # randomly permute data and labels
397 | p_l = rng.permutation(x_labelled.shape[0])
398 | x_labelled = x_labelled[p_l]
399 | y_labelled = y_labelled[p_l]
400 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
401 | shared_unlabel.set_value(x_unlabelled[p_u])
402 |
403 | for i in range(pretrain_batches_train_uc):
404 | i_c = i % pretrain_batches_train_lc
405 | from_l_c = i_c*pre_batch_size_lc
406 | to_l_c = (i_c+1)*pre_batch_size_lc
407 |
408 | pretrain_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], i)
409 | acc = evaluate(eval_x, eval_y)
410 | acc = acc[0]
411 |     print str(epoch) + ': Pretrain error: ' + str(1-acc)
412 |
413 |
414 | '''
415 | train and evaluate
416 | '''
417 | for epoch in range(1, 1+num_epochs):
418 | start = time.time()
419 |
420 | # randomly permute data and labels
421 | p_l = rng.permutation(x_labelled.shape[0])
422 | x_labelled = x_labelled[p_l]
423 | y_labelled = y_labelled[p_l]
424 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
425 | p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32')
426 | p_u_g = rng.permutation(x_unlabelled.shape[0]).astype('int32')
427 |
428 | dl = [0.] * len(dis_cost_list)
429 | gl = [0.] * len(gen_cost_list)
430 | cl = [0.] * len(cla_cost_list)
431 |
432 | # fast
433 |
434 | for i in range(n_batches_train_u_c):
435 | from_u_c = i*batch_size_u_c
436 | to_u_c = (i+1)*batch_size_u_c
437 | i_c = i % n_batches_train_l_c
438 | from_l_c = i_c*batch_size_l_c
439 | to_l_c = (i_c+1)*batch_size_l_c
440 | i_d = i % n_batches_train_l_d
441 | from_l_d = i_d*batch_size_l_d
442 | to_l_d = (i_d+1)*batch_size_l_d
443 | i_d_ = i % n_batches_train_u_d
444 | from_u_d = i_d_*batch_size_u_d
445 | to_u_d = (i_d_+1)*batch_size_u_d
446 |
447 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
448 |
449 | dl_b = train_batch_dis(x_labelled[from_l_d:to_l_d], y_labelled[from_l_d:to_l_d], sample_y, p_u[from_u_c:to_u_c], p_u_d[from_u_d:to_u_d], lr)
450 | for j in xrange(len(dl)):
451 | dl[j] += dl_b[j]
452 |
453 | gl_b = train_batch_gen(sample_y, lr)
454 | for j in xrange(len(gl)):
455 | gl[j] += gl_b[j]
456 |
457 | # fast
458 | if epoch < epoch_cla_g:
459 | cl_b = train_batch_cla_fast(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], p_u[from_u_c:to_u_c], lr, alpha_unlabeled_entropy, alpha_average)
460 | cl_b += [0,]
461 | else:
462 | cl_b = train_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], sample_y, p_u[from_u_c:to_u_c], alpha_cla_g, lr, alpha_unlabeled_entropy, alpha_average)
463 |
464 | for j in xrange(len(cl)):
465 | cl[j] += cl_b[j]
466 |
467 | for i in xrange(len(dl)):
468 | dl[i] /= n_batches_train_u_c
469 | for i in xrange(len(gl)):
470 | gl[i] /= n_batches_train_u_c
471 | for i in xrange(len(cl)):
472 | cl[i] /= n_batches_train_u_c
473 |
474 | if (epoch >= anneal_lr_epoch) and (epoch % anneal_lr_every_epoch == 0):
475 | lr = lr*anneal_lr_factor
476 |
477 | t = time.time() - start
478 |
479 | line = "*Epoch=%d Time=%.2f LR=%.5f\n" %(epoch, t, lr) + "DisLosses: " + str(dl)+"\nGenLosses: "+str(gl)+"\nClaLosses: "+str(cl)
480 |
481 | print line
482 | with open(logfile,'a') as f:
483 | f.write(line + "\n")
484 |
485 | # random generation for visualization
486 | if epoch % vis_epoch == 0:
487 | import utils.paramgraphics as paramgraphics
488 | tail = '-'+str(epoch)+'.png'
489 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes))
490 | x_gen = generate(ran_y)
491 | x_gen = x_gen.reshape((z_generated*num_classes,-1))
492 | image = paramgraphics.mat_to_img(x_gen.T, dim_input, colorImg=colorImg, scale=generation_scale, save_path=os.path.join(sample_path, 'sample'+tail))
493 |
494 | if epoch % eval_epoch == 0:
495 | acc = evaluate(eval_x, eval_y)
496 | acc = acc[0]
497 | print ('ErrorEval=%.5f\n' % (1-acc,))
498 | with open(logfile,'a') as f:
499 | f.write(('ErrorEval=%.5f\n\n' % (1-acc,)))
500 |
501 | if epoch % 200 == 0 or epoch == epoch_cla_g-1:
502 | from utils.checkpoints import save_weights
503 | params = ll.get_all_params(dis_layers+cla_layers+gen_layers)
504 | save_weights(os.path.join(outfolder, 'model_epoch' + str(epoch) + '.npy'), params, None)
505 |
506 | save_weights(os.path.join(outfolder, 'average.npy'), cla_param_avg, None)
507 |
--------------------------------------------------------------------------------
/x2y_yz2x_xy2p_ssl_svhn.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code implements Triple-GAN for semi-supervised learning on SVHN
3 | '''
4 | import os, time, argparse, shutil
5 |
6 | import numpy as np
7 | import theano, lasagne
8 | import theano.tensor as T
9 | import lasagne.layers as ll
10 | import lasagne.nonlinearities as ln
11 | from lasagne.layers import dnn
12 | import nn
13 | from lasagne.init import Normal
14 | from theano.sandbox.rng_mrg import MRG_RandomStreams
15 | import svhn_data
16 |
17 | from layers.merge import ConvConcatLayer, MLPConcatLayer
18 |
19 | from components.objectives import categorical_crossentropy_ssl_separated, categorical_crossentropy
20 | import utils.paramgraphics as paramgraphics
21 |
22 |
23 | '''
24 | parameters
25 | '''
26 | # global
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument("-key", type=str, default=argparse.SUPPRESS)
29 | parser.add_argument("-ssl_seed", type=int, default=1)
30 | parser.add_argument("-nlabeled", type=int, default=1000)
31 | parser.add_argument("-oldmodel", type=str, default=argparse.SUPPRESS)
32 | args = parser.parse_args()
33 | args = vars(args).items()
34 | cfg = {}
35 | for name, val in args:
36 | cfg[name] = val
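    | # example invocation: python x2y_yz2x_xy2p_ssl_svhn.py -ssl_seed 1 -nlabeled 1000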
37 |
38 | filename_script=os.path.basename(os.path.realpath(__file__))
39 | outfolder=os.path.join("results-ssl", os.path.splitext(filename_script)[0])
40 | outfolder+='.'
41 | for item in cfg:
42 |     if item != 'oldmodel':  # compare by value; 'is not' relies on string interning
43 | outfolder += item+str(cfg[item])+'.'
44 | else:
45 | outfolder += 'oldmodel.'
46 | outfolder+=str(int(time.time()))
47 | if not os.path.exists(outfolder):
48 | os.makedirs(outfolder)
49 | sample_path = os.path.join(outfolder, 'sample')
50 | os.makedirs(sample_path)
51 | logfile=os.path.join(outfolder, 'logfile.log')
52 | shutil.copy(os.path.realpath(__file__), os.path.join(outfolder, filename_script))
53 | # fixed random seeds
54 | ssl_data_seed=cfg['ssl_seed']
55 | num_labelled=cfg['nlabeled']
56 | print ssl_data_seed, num_labelled
57 |
58 | seed=1234
59 | rng=np.random.RandomState(seed)
60 | theano_rng=MRG_RandomStreams(rng.randint(2 ** 15))
61 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15)))
62 |
63 | # dataset
64 | data_dir='/home/chongxuan/mfs/data/svhn'
65 |
66 | # flags
67 | valid_flag=False
68 | objective_flag='argmax' # 'integrate' marginalizes the unlabelled loss over y; 'argmax' uses the classifier's hard prediction
69 |
70 | # pre-train C
71 | pre_num_epoch=0
72 | pre_alpha_unlabeled_entropy=.3
73 | pre_alpha_average=1e-3
74 | pre_lr=3e-4
75 | pre_batch_size_lc=200
76 | pre_batch_size_uc=200
77 | # C
78 | alpha_decay=1e-4
79 | alpha_labeled=1.
80 | alpha_unlabeled_entropy=.3
81 | alpha_average=1e-3
82 | alpha_cla=1.
83 | alpha_cla_adversarial=.01
84 | # G
85 | n_z=100
86 | alpha_cla_g=0.03
87 | epoch_cla_g=200
88 | # D
89 | noise_D_data=.3
90 | noise_D=.5
91 | # optimization
92 | b1_g=.5 # mom1 in Adam
93 | b1_d=.5
94 | b1_c=.5
95 | batch_size_g=200
96 | batch_size_l_c=200
97 | batch_size_u_c=200
98 | batch_size_u_d=180
99 | batch_size_l_d=20
100 | lr=3e-4
101 | num_epochs=1000
102 | anneal_lr_epoch=300
103 | anneal_lr_every_epoch=1
104 | anneal_lr_factor=.995
105 | # data dependent
106 | gen_final_non=ln.tanh
107 | num_classes=10
108 | dim_input=(32,32)
109 | in_channels=3
110 | colorImg=True
111 | generation_scale=True
112 | z_generated=num_classes
113 | # evaluation
114 | vis_epoch=10
115 | eval_epoch=1
116 | batch_size_eval=200
117 |
118 |
119 | '''
120 | data
121 | '''
122 | if valid_flag:
123 | def rescale(mat):
124 | return (((-127.5 + mat)/127.5).astype(theano.config.floatX)).reshape((-1, in_channels)+dim_input)
125 |     train_x, train_y, valid_x, valid_y, eval_x, eval_y = load_svhn_small(data_dir, num_val=5000) # assumes a loader with a validation split; not exercised since valid_flag is False above
126 | eval_x = valid_x
127 | eval_y = valid_y
128 | else:
129 | def rescale(mat):
130 | return np.transpose(np.cast[theano.config.floatX]((-127.5 + mat)/127.5),(3,2,0,1))
131 | train_x, train_y = svhn_data.load('/home/chongxuan/mfs/data/svhn/','train')
132 | eval_x, eval_y = svhn_data.load('/home/chongxuan/mfs/data/svhn/','test')
133 |
134 | train_y = np.int32(train_y)
135 | eval_y = np.int32(eval_y)
136 | train_x = rescale(train_x)
137 | eval_x = rescale(eval_x)
138 | x_unlabelled = train_x.copy()
139 |
140 | print train_x.shape, eval_x.shape
141 |
142 | rng_data = np.random.RandomState(ssl_data_seed)
143 | inds = rng_data.permutation(train_x.shape[0])
144 | train_x = train_x[inds]
145 | train_y = train_y[inds]
146 | x_labelled = []
147 | y_labelled = []
148 | for j in range(num_classes):
149 | x_labelled.append(train_x[train_y==j][:num_labelled/num_classes])
150 | y_labelled.append(train_y[train_y==j][:num_labelled/num_classes])
151 | x_labelled = np.concatenate(x_labelled, axis=0)
152 | y_labelled = np.concatenate(y_labelled, axis=0)
153 | del train_x
154 |
155 | if True:
156 | print 'Size of training data', x_labelled.shape[0], x_unlabelled.shape[0]
157 | y_order = np.argsort(y_labelled)
158 | _x_mean = x_labelled[y_order]
159 | image = paramgraphics.mat_to_img(_x_mean.T, dim_input, tile_shape=(num_classes, num_labelled/num_classes), colorImg=colorImg, scale=generation_scale, save_path=os.path.join(outfolder, 'x_l_'+str(ssl_data_seed)+'_triple-gan.png'))
160 |
161 | pretrain_batches_train_uc = x_unlabelled.shape[0] / pre_batch_size_uc
162 | pretrain_batches_train_lc = x_labelled.shape[0] / pre_batch_size_lc
163 | n_batches_train_u_c = x_unlabelled.shape[0] / batch_size_u_c
164 | n_batches_train_l_c = x_labelled.shape[0] / batch_size_l_c
165 | n_batches_train_u_d = x_unlabelled.shape[0] / batch_size_u_d
166 | n_batches_train_l_d = x_labelled.shape[0] / batch_size_l_d
167 | n_batches_train_g = x_unlabelled.shape[0] / batch_size_g
168 | n_batches_eval = eval_x.shape[0] / batch_size_eval
169 |
170 |
171 | '''
172 | models
173 | '''
174 | # symbols
175 | sym_z_image = T.tile(theano_rng.uniform((z_generated, n_z)), (num_classes, 1))
176 | sym_z_rand = theano_rng.uniform(size=(batch_size_g, n_z))
177 | sym_x_u = T.tensor4()
178 | sym_x_u_d = T.tensor4()
179 | sym_x_u_g = T.tensor4()
180 | sym_x_l = T.tensor4()
181 | sym_y = T.ivector()
182 | sym_y_g = T.ivector()
183 | sym_x_eval = T.tensor4()
184 | sym_lr = T.scalar()
185 | sym_alpha_cla_g = T.scalar()
186 | sym_alpha_unlabel_entropy = T.scalar()
187 | sym_alpha_unlabel_average = T.scalar()
188 |
189 | shared_unlabel = theano.shared(x_unlabelled, borrow=True)
190 | slice_x_u_g = T.ivector()
191 | slice_x_u_d = T.ivector()
192 | slice_x_u_c = T.ivector()
193 |
194 | # classifier x2y: p_c(x, y) = p(x) p_c(y | x)
195 | cla_in_x = ll.InputLayer(shape=(None, in_channels) + dim_input)
196 | cla_layers = [cla_in_x]
197 | cla_layers.append(ll.DropoutLayer(cla_layers[-1], p=0.2, name='cla-00'))
198 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 128, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-02'), name='cla-03'))
199 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 128, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-11'), name='cla-12'))
200 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 128, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-21'), name='cla-22'))
201 | cla_layers.append(dnn.MaxPool2DDNNLayer(cla_layers[-1], pool_size=(2, 2)))
202 | cla_layers.append(ll.DropoutLayer(cla_layers[-1], p=0.5, name='cla-23'))
203 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 256, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-31'), name='cla-32'))
204 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 256, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-41'), name='cla-42'))
205 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 256, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-51'), name='cla-52'))
206 | cla_layers.append(dnn.MaxPool2DDNNLayer(cla_layers[-1], pool_size=(2, 2)))
207 | cla_layers.append(ll.DropoutLayer(cla_layers[-1], p=0.5, name='cla-53'))
208 | cla_layers.append(ll.batch_norm(dnn.Conv2DDNNLayer(cla_layers[-1], 512, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-61'), name='cla-62'))
209 | cla_layers.append(ll.batch_norm(ll.NINLayer(cla_layers[-1], num_units=256, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-71'), name='cla-72'))
210 | cla_layers.append(ll.batch_norm(ll.NINLayer(cla_layers[-1], num_units=128, W=Normal(0.05), nonlinearity=nn.lrelu, name='cla-81'), name='cla-82'))
211 | cla_layers.append(ll.GlobalPoolLayer(cla_layers[-1], name='cla-83'))
212 | cla_layers.append(ll.batch_norm(ll.DenseLayer(cla_layers[-1], num_units=num_classes, W=Normal(0.05), nonlinearity=ln.softmax, name='cla-91'), name='cla-92'))
213 |
214 | # generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z)
215 | gen_in_z = ll.InputLayer(shape=(None, n_z))
216 | gen_in_y = ll.InputLayer(shape=(None,))
217 | gen_layers = [gen_in_z]
218 | gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-00'))
219 | gen_layers.append(nn.batch_norm(ll.DenseLayer(gen_layers[-1], num_units=4*4*512, W=Normal(0.05), nonlinearity=nn.relu, name='gen-01'), g=None, name='gen-02'))
220 | gen_layers.append(ll.ReshapeLayer(gen_layers[-1], (-1,512,4,4), name='gen-03'))
221 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-10'))
222 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,256,8,8), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-11'), g=None, name='gen-12')) # 4 -> 8
223 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-20'))
224 | gen_layers.append(nn.batch_norm(nn.Deconv2DLayer(gen_layers[-1], (None,128,16,16), (5,5), W=Normal(0.05), nonlinearity=nn.relu, name='gen-21'), g=None, name='gen-22')) # 8 -> 16
225 | gen_layers.append(ConvConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-30'))
226 | gen_layers.append(nn.weight_norm(nn.Deconv2DLayer(gen_layers[-1], (None,3,32,32), (5,5), W=Normal(0.05), nonlinearity=gen_final_non, name='gen-31'), train_g=True, init_stdv=0.1, name='gen-32')) # 16 -> 32
227 |
228 | # discriminator xy2p: tests whether an (x, y) pair comes from the true distribution p(x, y) rather than from p_c or p_g
229 | dis_in_x = ll.InputLayer(shape=(None, in_channels) + dim_input)
230 | dis_in_y = ll.InputLayer(shape=(None,))
231 | dis_layers = [dis_in_x]
232 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-00'))
233 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-01'))
234 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-02'), name='dis-03'))
235 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-20'))
236 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 32, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-21'), name='dis-22'))
237 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-23'))
238 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-30'))
239 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-31'), name='dis-32'))
240 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-40'))
241 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 64, (3,3), pad=1, stride=2, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-41'), name='dis-42'))
242 | dis_layers.append(ll.DropoutLayer(dis_layers[-1], p=0.2, name='dis-43'))
243 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-50'))
244 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-51'), name='dis-52'))
245 | dis_layers.append(ConvConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-60'))
246 | dis_layers.append(nn.weight_norm(dnn.Conv2DDNNLayer(dis_layers[-1], 128, (3,3), pad=0, W=Normal(0.05), nonlinearity=nn.lrelu, name='dis-61'), name='dis-62'))
247 | dis_layers.append(ll.GlobalPoolLayer(dis_layers[-1], name='dis-63'))
248 | dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-70'))
249 | dis_layers.append(nn.weight_norm(ll.DenseLayer(dis_layers[-1], num_units=1, W=Normal(0.05), nonlinearity=ln.sigmoid, name='dis-71'), train_g=True, init_stdv=0.1, name='dis-72'))
250 |
251 |
252 | '''
253 | objectives
254 | '''
255 | # outputs
256 | gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_rand}, deterministic=False)
257 |
258 | cla_out_y_l = ll.get_output(cla_layers[-1], sym_x_l, deterministic=False)
259 | cla_out_y_eval = ll.get_output(cla_layers[-1], sym_x_eval, deterministic=True)
260 | cla_out_y = ll.get_output(cla_layers[-1], sym_x_u, deterministic=False)
261 | cla_out_y_d = ll.get_output(cla_layers[-1], {cla_in_x:sym_x_u_d}, deterministic=False)
262 | cla_out_y_d_hard = cla_out_y_d.argmax(axis=1)
263 | cla_out_y_g = ll.get_output(cla_layers[-1], {cla_in_x:gen_out_x}, deterministic=False)
264 |
265 | dis_out_p = ll.get_output(dis_layers[-1], {dis_in_x:T.concatenate([sym_x_l,sym_x_u_d], axis=0),dis_in_y:T.concatenate([sym_y,cla_out_y_d_hard], axis=0)}, deterministic=False)
266 | dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x:gen_out_x,dis_in_y:sym_y_g}, deterministic=False)
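    | # D's "real" batch mixes labelled pairs (sym_x_l, sym_y) with unlabelled images paired with the
    | # classifier's hard predictions; the fake pairs come from G (dis_out_p_g) and, below, from C (dis_out_p_c)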
267 |
268 | if objective_flag == 'integrate':
269 | # integrate
270 | dis_out_p_c = ll.get_output(dis_layers[-1],
271 | {dis_in_x:T.repeat(sym_x_u, num_classes, axis=0),
272 | dis_in_y:np.tile(np.arange(num_classes), batch_size_u_c)},
273 | deterministic=False)
274 | elif objective_flag == 'argmax':
275 | # argmax approximation
276 | cla_out_y_hard = cla_out_y.argmax(axis=1)
277 | dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x:sym_x_u,dis_in_y:cla_out_y_hard}, deterministic=False)
278 | else:
279 | raise Exception('Unknown objective flags')
280 |
281 | image = ll.get_output(gen_layers[-1], {gen_in_y:sym_y_g, gen_in_z:sym_z_image}, deterministic=False) # for generation
282 |
283 | accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y)) # for evaluation
284 | accurracy_eval = accurracy_eval.mean()
285 |
286 | # costs
287 | bce = lasagne.objectives.binary_crossentropy
288 |
289 | dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean() # D scores pairs from p as real
290 | dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean() # D scores pairs from p_g as fake
291 | gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean() # G fools D
292 |
293 | weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted({cla_layers[-1]:1}, lasagne.regularization.l2) # weight decay
294 |
295 | dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape)) # D scores pairs from p_c as fake
296 | cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape)) # C fools D
297 |
298 | if objective_flag == 'integrate':
299 | # integrate
300 | weight_loss_c = T.reshape(cla_cost_p_c, (-1, num_classes)) * cla_out_y
301 | cla_cost_p_c = T.sum(weight_loss_c, axis=1).mean()
302 | weight_loss_d = T.reshape(dis_cost_p_c, (-1, num_classes)) * cla_out_y
303 | dis_cost_p_c = T.sum(weight_loss_d, axis=1).mean()
304 | elif objective_flag == 'argmax':
305 | # argmax approximation
306 | p = cla_out_y.max(axis=1)
307 | cla_cost_p_c = (cla_cost_p_c*p).mean()
308 | dis_cost_p_c = dis_cost_p_c.mean()
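    | # 'integrate' pairs every unlabelled x with all num_classes labels and averages the adversarial loss
    | # under p_c(y|x), i.e. E_{y~p_c(y|x)}[bce(D(x,y), target)]; 'argmax' uses only the most probable label
    | # and down-weights the classifier's adversarial loss by its confidence p = max_y p_c(y|x)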
309 |
310 | cla_cost_cla = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=sym_alpha_unlabel_entropy, alpha_average=sym_alpha_unlabel_average, alpha_decay=alpha_decay) # classification loss
311 |
312 | pretrain_cla_loss = categorical_crossentropy_ssl_separated(predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y, weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled, alpha_unlabeled=pre_alpha_unlabeled_entropy, alpha_average=pre_alpha_average, alpha_decay=alpha_decay) # classification loss
313 | pretrain_cost = pretrain_cla_loss
314 |
315 | cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g)
316 |
317 | dis_cost = dis_cost_p + .5*dis_cost_p_g + .5*dis_cost_p_c
318 | gen_cost = .5*gen_cost_p_g
319 | # full classifier loss (the generated-sample term is gated by sym_alpha_cla_g)
320 | cla_cost = alpha_cla_adversarial*.5*cla_cost_p_c + cla_cost_cla + sym_alpha_cla_g*cla_cost_cla_g
321 | # fast classifier loss used before epoch_cla_g (no generated-sample term)
322 | cla_cost_fast = alpha_cla_adversarial*.5*cla_cost_p_c + cla_cost_cla
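    | # D pushes real pairs towards 1 and splits the fake mass equally (weights .5) between G and C pairs;
    | # C's adversarial term is scaled by alpha_cla_adversarial, and the generated-sample term only enters
    | # after epoch_cla_g (see the training loop below)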
323 |
324 | dis_cost_list=[dis_cost, dis_cost_p, .5*dis_cost_p_g, .5*dis_cost_p_c]
325 | gen_cost_list=[gen_cost,]
326 | # components of the full classifier loss, accumulated for logging
327 | cla_cost_list=[cla_cost, alpha_cla_adversarial*.5*cla_cost_p_c, cla_cost_cla, sym_alpha_cla_g*cla_cost_cla_g]
328 | # components of the fast classifier loss
329 | cla_cost_list_fast=[cla_cost_fast, alpha_cla_adversarial*.5*cla_cost_p_c, cla_cost_cla]
330 |
331 | # updates of D
332 | dis_params = ll.get_all_params(dis_layers, trainable=True)
333 | dis_grads = T.grad(dis_cost, dis_params)
334 | dis_updates = lasagne.updates.adam(dis_grads, dis_params, beta1=b1_d, learning_rate=sym_lr)
335 |
336 | # updates of G
337 | gen_params = ll.get_all_params(gen_layers, trainable=True)
338 | gen_grads = T.grad(gen_cost, gen_params)
339 | gen_updates = lasagne.updates.adam(gen_grads, gen_params, beta1=b1_g, learning_rate=sym_lr)
340 |
341 | # updates of C
342 | cla_params = ll.get_all_params(cla_layers, trainable=True)
343 | cla_grads = T.grad(cla_cost, cla_params)
344 | cla_updates_ = lasagne.updates.adam(cla_grads, cla_params, beta1=b1_c, learning_rate=sym_lr)
345 |
346 | # fast updates of C
347 | cla_params = ll.get_all_params(cla_layers, trainable=True)
348 | cla_grads_fast = T.grad(cla_cost_fast, cla_params)
349 | cla_updates_fast_ = lasagne.updates.adam(cla_grads_fast, cla_params, beta1=b1_c, learning_rate=sym_lr)
350 |
351 | pre_cla_grad = T.grad(pretrain_cost, cla_params)
352 | pretrain_updates_ = lasagne.updates.adam(pre_cla_grad, cla_params, beta1=0.9, beta2=0.999,
353 | epsilon=1e-8, learning_rate=pre_lr)
354 |
355 | ######## avg
356 | avg_params = lasagne.layers.get_all_params(cla_layers)
357 | cla_param_avg=[]
358 | for param in avg_params:
359 | value = param.get_value(borrow=True)
360 | cla_param_avg.append(theano.shared(np.zeros(value.shape, dtype=value.dtype),
361 | broadcastable=param.broadcastable,
362 | name=param.name))
363 | cla_avg_updates = [(a,a + 0.01*(p-a)) for p,a in zip(avg_params,cla_param_avg)]
364 | cla_avg_givens = [(p,a) for p,a in zip(avg_params, cla_param_avg)]
365 | cla_updates = cla_updates_.items() + cla_avg_updates
366 | cla_updates_fast = cla_updates_fast_.items() + cla_avg_updates
367 | pretrain_updates = pretrain_updates_.items() + cla_avg_updates
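    | # every classifier update also advances an exponential moving average of its parameters (rate 0.01);
    | # evaluation swaps these averaged values in through cla_avg_givens below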
368 |
369 | # functions
370 | train_batch_dis = theano.function(inputs=[sym_x_l, sym_y, sym_y_g,
371 | slice_x_u_c, slice_x_u_d, sym_lr],
372 | outputs=dis_cost_list, updates=dis_updates,
373 | givens={sym_x_u: shared_unlabel[slice_x_u_c],
374 | sym_x_u_d: shared_unlabel[slice_x_u_d]})
375 | train_batch_gen = theano.function(inputs=[sym_y_g, sym_lr],
376 | outputs=gen_cost_list, updates=gen_updates)
377 | train_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_y_g, slice_x_u_c, sym_alpha_cla_g, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average],
378 | outputs=cla_cost_list, updates=cla_updates,
379 | givens={sym_x_u: shared_unlabel[slice_x_u_c]})
380 | # fast classifier update (no generated-sample term)
381 | train_batch_cla_fast = theano.function(inputs=[sym_x_l, sym_y, slice_x_u_c, sym_lr, sym_alpha_unlabel_entropy, sym_alpha_unlabel_average],
382 | outputs=cla_cost_list_fast, updates=cla_updates_fast,
383 | givens={sym_x_u: shared_unlabel[slice_x_u_c]})
384 |
385 | sym_index = T.iscalar()
386 | bslice = slice(sym_index*pre_batch_size_uc, (sym_index+1)*pre_batch_size_uc)
387 | pretrain_batch_cla = theano.function(inputs=[sym_x_l, sym_y, sym_index],
388 | outputs=[pretrain_cost], updates=pretrain_updates,
389 | givens={sym_x_u: shared_unlabel[bslice]})
390 | generate = theano.function(inputs=[sym_y_g], outputs=image)
391 | # evaluate with the averaged classifier parameters (cla_avg_givens)
392 | evaluate = theano.function(inputs=[sym_x_eval, sym_y], outputs=[accurracy_eval], givens=cla_avg_givens)
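    | # the unlabelled images live in a shared variable (shared_unlabel) so they can stay on the device;
    | # each training function receives an index vector and selects its minibatch via givens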
393 |
394 | '''
395 | Load pretrained model
396 | '''
397 | if 'oldmodel' in cfg:
398 | from utils.checkpoints import load_weights
399 | load_weights(cfg['oldmodel'], dis_layers+cla_layers+gen_layers)
400 |     for (p, a) in zip(avg_params, cla_param_avg):  # seed the averaged copies with the loaded weights
401 | a.set_value(p.get_value())
402 |
403 |
404 | '''
405 | Pretrain C
406 | '''
407 | for epoch in range(1, 1+pre_num_epoch):
408 | # randomly permute data and labels
409 | p_l = rng.permutation(x_labelled.shape[0])
410 | x_labelled = x_labelled[p_l]
411 | y_labelled = y_labelled[p_l]
412 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
413 | shared_unlabel.set_value(x_unlabelled[p_u])
414 |
415 | for i in range(pretrain_batches_train_uc):
416 | i_c = i % pretrain_batches_train_lc
417 | from_l_c = i_c*pre_batch_size_lc
418 | to_l_c = (i_c+1)*pre_batch_size_lc
419 |
420 | pretrain_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], i)
421 | accurracy=[]
422 | for i in range(n_batches_eval):
423 | accurracy_batch = evaluate(eval_x[i*batch_size_eval:(i+1)*batch_size_eval], eval_y[i*batch_size_eval:(i+1)*batch_size_eval])
424 | accurracy += accurracy_batch
425 | accurracy=np.mean(accurracy)
426 |     print str(epoch) + ': Pretrain error: ' + str(1-accurracy)
427 |
428 | '''
429 | train and evaluate
430 | '''
431 | print 'Start training'
432 | for epoch in range(1, 1+num_epochs):
433 | start = time.time()
434 |
435 | # randomly permute data and labels
436 | p_l = rng.permutation(x_labelled.shape[0])
437 | x_labelled = x_labelled[p_l]
438 | y_labelled = y_labelled[p_l]
439 | p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
440 | p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32')
441 | p_u_g = rng.permutation(x_unlabelled.shape[0]).astype('int32')
442 |
443 | dl = [0.] * len(dis_cost_list)
444 | gl = [0.] * len(gen_cost_list)
445 | cl = [0.] * len(cla_cost_list)
446 |
447 |     # per-component wall-clock timing
448 | dis_t=0.
449 | gen_t=0.
450 | cla_t=0.
451 | for i in range(n_batches_train_u_c):
452 | from_u_c = i*batch_size_u_c
453 | to_u_c = (i+1)*batch_size_u_c
454 | i_c = i % n_batches_train_l_c
455 | from_l_c = i_c*batch_size_l_c
456 | to_l_c = (i_c+1)*batch_size_l_c
457 | i_d = i % n_batches_train_l_d
458 | from_l_d = i_d*batch_size_l_d
459 | to_l_d = (i_d+1)*batch_size_l_d
460 | i_d_ = i % n_batches_train_u_d
461 | from_u_d = i_d_*batch_size_u_d
462 | to_u_d = (i_d_+1)*batch_size_u_d
463 |
464 | sample_y = np.int32(np.repeat(np.arange(num_classes), batch_size_g/num_classes))
465 |
466 | tmp = time.time()
467 | dl_b = train_batch_dis(x_labelled[from_l_d:to_l_d], y_labelled[from_l_d:to_l_d], sample_y, p_u[from_u_c:to_u_c], p_u_d[from_u_d:to_u_d], lr)
468 | for j in xrange(len(dl)):
469 | dl[j] += dl_b[j]
470 |
471 | tmp1 = time.time()
472 |
473 | gl_b = train_batch_gen(sample_y, lr)
474 | for j in xrange(len(gl)):
475 | gl[j] += gl_b[j]
476 |
477 | tmp2 = time.time()
478 |
479 |         # fast path: before epoch_cla_g, train C without the loss on generated samples
480 | if epoch < epoch_cla_g:
481 | cl_b = train_batch_cla_fast(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], p_u[from_u_c:to_u_c], lr, alpha_unlabeled_entropy, alpha_average)
482 | cl_b += [0,]
483 | else:
484 | cl_b = train_batch_cla(x_labelled[from_l_c:to_l_c], y_labelled[from_l_c:to_l_c], sample_y, p_u[from_u_c:to_u_c], alpha_cla_g, lr, alpha_unlabeled_entropy, alpha_average)
485 |
486 | for j in xrange(len(cl)):
487 | cl[j] += cl_b[j]
488 |
489 | tmp3 = time.time()
490 | dis_t+=(tmp1-tmp)
491 | gen_t+=(tmp2-tmp1)
492 | cla_t+=(tmp3-tmp2)
493 |
494 | print 'dis:', dis_t, 'gen:', gen_t, 'cla:', cla_t, 'total', dis_t+gen_t+cla_t
495 |
496 | for i in xrange(len(dl)):
497 | dl[i] /= n_batches_train_u_c
498 | for i in xrange(len(gl)):
499 | gl[i] /= n_batches_train_u_c
500 | for i in xrange(len(cl)):
501 | cl[i] /= n_batches_train_u_c
502 |
503 | if (epoch >= anneal_lr_epoch) and (epoch % anneal_lr_every_epoch == 0):
504 | lr = lr*anneal_lr_factor
505 |
506 | t = time.time() - start
507 |
508 | line = "*Epoch=%d Time=%.2f LR=%.5f\n" %(epoch, t, lr) + "DisLosses: " + str(dl)+"\nGenLosses: "+str(gl)+"\nClaLosses: "+str(cl)
509 |
510 | print line
511 | with open(logfile,'a') as f:
512 | f.write(line + "\n")
513 |
514 | # random generation for visualization
515 | if epoch % vis_epoch == 0:
516 | import utils.paramgraphics as paramgraphics
517 | tail = '-'+str(epoch)+'.png'
518 | ran_y = np.int32(np.repeat(np.arange(num_classes), num_classes))
519 | x_gen = generate(ran_y)
520 | x_gen = x_gen.reshape((z_generated*num_classes,-1))
521 | image = paramgraphics.mat_to_img(x_gen.T, dim_input, colorImg=colorImg, scale=generation_scale, save_path=os.path.join(sample_path, 'sample'+tail))
522 |
523 | if epoch % eval_epoch == 0:
524 | accurracy=[]
525 | for i in range(n_batches_eval):
526 | accurracy_batch = evaluate(eval_x[i*batch_size_eval:(i+1)*batch_size_eval], eval_y[i*batch_size_eval:(i+1)*batch_size_eval])
527 | accurracy += accurracy_batch
528 | accurracy=np.mean(accurracy)
529 | print ('ErrorEval=%.5f\n' % (1-accurracy,))
530 | with open(logfile,'a') as f:
531 | f.write(('ErrorEval=%.5f\n\n' % (1-accurracy,)))
532 |
533 | if epoch % 200 == 0 or (epoch == epoch_cla_g - 1):
534 | from utils.checkpoints import save_weights
535 | params = ll.get_all_params(dis_layers+cla_layers+gen_layers)
536 | save_weights(os.path.join(outfolder, 'model_epoch' + str(epoch) + '.npy'), params, None)
537 | save_weights(os.path.join(outfolder, 'average'+ str(epoch) +'.npy'), cla_param_avg, None)
538 |
--------------------------------------------------------------------------------
/zca_bn.py:
--------------------------------------------------------------------------------
1 | # ZCA and MeanOnlyBNLayer implementations copied from
2 | # https://github.com/TimSalimans/weight_norm/blob/master/nn.py
3 | #
4 | # Modifications made to MeanOnlyBNLayer:
5 | # - Added configurable momentum.
6 | # - Added 'modify_incoming' flag for weight matrix sharing (not used in this project).
7 | # - Sums and means use float32 datatype.
8 |
9 | import numpy as np
10 | import theano as th
11 | import theano.tensor as T
12 | from scipy import linalg
13 | import lasagne
14 |
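    | # ZCA whitening: with covariance sigma = U S U^T estimated from the mean-centered training data,
    | # apply() multiplies by U diag(1/sqrt(S + reg)) U^T and invert() by U diag(sqrt(S + reg)) U^T,
    | # so the two are exact inverses up to the regularization term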
15 | class ZCA(object):
16 | def __init__(self, regularization=1e-5, x=None):
17 | self.regularization = regularization
18 | if x is not None:
19 | self.fit(x)
20 |
21 | def fit(self, x):
22 | s = x.shape
23 | x = x.copy().reshape((s[0],np.prod(s[1:])))
24 | m = np.mean(x, axis=0)
25 | x -= m
26 | sigma = np.dot(x.T,x) / x.shape[0]
27 | U, S, V = linalg.svd(sigma)
28 | tmp = np.dot(U, np.diag(1./np.sqrt(S+self.regularization)))
29 | tmp2 = np.dot(U, np.diag(np.sqrt(S+self.regularization)))
30 | self.ZCA_mat = th.shared(np.dot(tmp, U.T).astype(th.config.floatX))
31 | self.inv_ZCA_mat = th.shared(np.dot(tmp2, U.T).astype(th.config.floatX))
32 | self.mean = th.shared(m.astype(th.config.floatX))
33 |
34 | def apply(self, x):
35 | s = x.shape
36 | if isinstance(x, np.ndarray):
37 | return np.dot(x.reshape((s[0],np.prod(s[1:]))) - self.mean.get_value(), self.ZCA_mat.get_value()).reshape(s)
38 | elif isinstance(x, T.TensorVariable):
39 | return T.dot(x.flatten(2) - self.mean.dimshuffle('x',0), self.ZCA_mat).reshape(s)
40 | else:
41 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables")
42 |
43 | def invert(self, x):
44 | s = x.shape
45 | if isinstance(x, np.ndarray):
46 | return (np.dot(x.reshape((s[0],np.prod(s[1:]))), self.inv_ZCA_mat.get_value()) + self.mean.get_value()).reshape(s)
47 | elif isinstance(x, T.TensorVariable):
48 | return (T.dot(x.flatten(2), self.inv_ZCA_mat) + self.mean.dimshuffle('x',0)).reshape(s)
49 | else:
50 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables")
51 |
52 | # T.nnet.relu has some issues with very large inputs; this is more stable
53 | def relu(x):
54 | return T.maximum(x, 0)
55 |
56 | class MeanOnlyBNLayer(lasagne.layers.Layer):
57 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.),
58 | W=lasagne.init.Normal(0.05), nonlinearity=relu, modify_incoming=True, momentum=0.9, **kwargs):
59 | super(MeanOnlyBNLayer, self).__init__(incoming, **kwargs)
60 | self.nonlinearity = nonlinearity
61 | self.momentum = momentum
62 | k = self.input_shape[1]
63 | if b is not None:
64 | self.b = self.add_param(b, (k,), name="b", regularizable=False)
65 | if g is not None:
66 | self.g = self.add_param(g, (k,), name="g")
67 | self.avg_batch_mean = self.add_param(lasagne.init.Constant(0.), (k,), name="avg_batch_mean", regularizable=False, trainable=False)
68 | if len(self.input_shape)==4:
69 | self.axes_to_sum = (0,2,3)
70 | self.dimshuffle_args = ['x',0,'x','x']
71 | else:
72 | self.axes_to_sum = 0
73 | self.dimshuffle_args = ['x',0]
74 |
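    |         # the incoming layer's weights are re-parameterized as in weight normalization,
    |         # W = g * V / ||V|| per output unit, so this layer combines weight norm with mean-only BN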
75 | # scale weights in layer below
76 | incoming.W_param = incoming.W
77 | if modify_incoming:
78 | incoming.W_param.set_value(W.sample(incoming.W_param.get_value().shape))
79 | if incoming.W_param.ndim==4:
80 | W_axes_to_sum = (1,2,3)
81 | W_dimshuffle_args = [0,'x','x','x']
82 | else:
83 | W_axes_to_sum = 0
84 | W_dimshuffle_args = ['x',0]
85 | if g is not None:
86 | incoming.W = incoming.W_param * (self.g/T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,dtype=th.config.floatX, acc_dtype=th.config.floatX))).dimshuffle(*W_dimshuffle_args)
87 | else:
88 | incoming.W = incoming.W_param / T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,keepdims=True,dtype=th.config.floatX,acc_dtype=th.config.floatX))
89 |
90 | def get_output_for(self, input, deterministic=False, init=False, **kwargs):
91 | if deterministic:
92 | activation = input - self.avg_batch_mean.dimshuffle(*self.dimshuffle_args)
93 | else:
94 | m = T.mean(input,axis=self.axes_to_sum,dtype=th.config.floatX,acc_dtype=th.config.floatX)
95 | activation = input - m.dimshuffle(*self.dimshuffle_args)
96 | self.bn_updates = [(self.avg_batch_mean, self.momentum*self.avg_batch_mean + (1.0-self.momentum)*m)]
97 | if init:
98 | stdv = T.sqrt(T.mean(T.square(activation),axis=self.axes_to_sum,dtype=th.config.floatX,acc_dtype=th.config.floatX))
99 | activation /= stdv.dimshuffle(*self.dimshuffle_args)
100 | self.init_updates = [(self.g, self.g/stdv)]
101 | if hasattr(self, 'b'):
102 | activation += self.b.dimshuffle(*self.dimshuffle_args)
103 |
104 | return self.nonlinearity(activation)
105 |
106 | def mean_only_bn(layer, **kwargs):
107 | nonlinearity = getattr(layer, 'nonlinearity', None)
108 | if nonlinearity is not None:
109 | layer.nonlinearity = lasagne.nonlinearities.identity
110 | if hasattr(layer, 'b') and layer.b is not None:
111 | del layer.params[layer.b]
112 | layer.b = None
113 | return MeanOnlyBNLayer(layer, name=layer.name+'_n', nonlinearity=nonlinearity, **kwargs)
114 |
--------------------------------------------------------------------------------