├── README.md ├── cifar_gan.py ├── data ├── cifar10_input.py └── svhn_data.py ├── figure ├── loss_cifar.png └── test_accuracy_cifar.png ├── nn.py ├── svhn_gan.py ├── train_cifar.py └── train_svhn.py /README.md: -------------------------------------------------------------------------------- 1 | # ImprovedGAN-Tensorflow 2 | 3 | This is a simple TensorFlow implementation of the Semi-Supervised GAN proposed in the paper Improved Techniques 4 | for Training GANs by Salimans et al. https://arxiv.org/abs/1606.03498 5 | 6 | The code reproduces the results presented in the original paper. It uses the same tricks as the original Theano implementation. 7 | 8 | ## Requirements 9 | 10 | The repo supports Python 3.5 and TensorFlow 1.5. 11 | 12 | ## Run the Code 13 | 14 | To reproduce our results on SVHN: 15 | ``` 16 | python train_svhn.py 17 | ``` 18 | 19 | To reproduce our results on CIFAR-10: 20 | ``` 21 | python train_cifar.py 22 | ``` 23 | 24 | ## Experiments 25 | 26 | CIFAR-10 (% errors) | 1000 labels | 4000 labels 27 | -- | -- | -- 28 | Improved GAN (ours) | **20.24 +/- 2.17** | **17.34 +/- 1.97** 29 | 30 | SVHN (% errors) | 400 labels | 1000 labels 31 | -- | -- | -- 32 | Improved GAN (ours) | **5.22 +/- 1.02** | **4.12 +/- 1.23** 33 | -------------------------------------------------------------------------------- /cifar_gan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import nn 3 | 4 | init_kernel = tf.random_normal_initializer(mean=0, stddev=0.05) 5 | 6 | 7 | def leakyReLu(x, alpha=0.2, name=None): 8 | if name: 9 | with tf.variable_scope(name): 10 | return _leakyReLu_impl(x, alpha) 11 | else: 12 | return _leakyReLu_impl(x, alpha) 13 | 14 | 15 | def _leakyReLu_impl(x, alpha): 16 | return tf.nn.relu(x) - (alpha * tf.nn.relu(-x)) 17 | 18 | def gaussian_noise_layer(input_layer, std): 19 | noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 20 | return input_layer + noise 21 | 22 | 23 | def discriminator(inp, is_training, init=False, reuse=False, getter =None): 24 | with tf.variable_scope('discriminator_model', reuse=reuse,custom_getter=getter): 25 | counter = {} 26 | x = tf.reshape(inp, [-1, 32, 32, 3]) 27 | 28 | x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout_0') 29 | 30 | x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter) 31 | x = nn.conv2d(x, 96, nonlinearity=leakyReLu, init=init, counters=counter) 32 | x = nn.conv2d(x, 96, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter) 33 | 34 | x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_1') 35 | 36 | x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter) 37 | x = nn.conv2d(x, 192, nonlinearity=leakyReLu, init=init, counters=counter) 38 | x = nn.conv2d(x, 192, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter) 39 | 40 | x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_2') 41 | 42 | x = nn.conv2d(x, 192, pad='VALID', nonlinearity=leakyReLu, init=init, counters=counter) 43 | x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init) 44 | x = nn.nin(x, 192, counters=counter, nonlinearity=leakyReLu, init=init) 45 | x = tf.layers.max_pooling2d(x, pool_size=6, strides=1, 46 | name='avg_pool_0') # global 6x6 pooling over the final feature map (max pooling, despite the 'avg_pool' name) 47 | x = tf.squeeze(x, [1, 2]) 48 | 49 | intermediate_layer = x 50 | 51 | logits = nn.dense(x, 10, nonlinearity=None, init=init, counters=counter, init_scale=0.1) 52 | 53 | return logits, intermediate_layer
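# Hedged usage sketch (illustrative, not from the original repository): the weight-normalized
# layers in nn.py build their variables through a data-dependent initialization pass, so the
# model functions are meant to be called twice -- once with init=True on a real batch to create
# the weight-norm variables, then with reuse=True to build the graph that is actually trained
# (train_cifar.py follows this pattern). A minimal smoke test of the discriminator defined above,
# assuming a batch size of 4; the placeholder names below are illustrative only.
if __name__ == '__main__':
    import numpy as np
    x_pl = tf.placeholder(tf.float32, [4, 32, 32, 3], name='smoke_test_images')
    training_pl = tf.placeholder(tf.bool, [], name='smoke_test_is_training')
    discriminator(x_pl, training_pl, init=True)  # first pass: creates the weight-norm variables
    logits, features = discriminator(x_pl, training_pl, init=False, reuse=True)  # reused graph
    batch = np.random.randn(4, 32, 32, 3).astype(np.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(logits, {x_pl: batch, training_pl: False})
        print('logits shape:', out.shape)  # expected: (4, 10)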
54 | 55 | 56 | def generator(z_seed, is_training, init=False,reuse=False): 57 | with tf.variable_scope('generator_model', reuse=reuse): 58 | counter = {} 59 | x = z_seed 60 | with tf.variable_scope('dense_1'): 61 | x = tf.layers.dense(x, units=4 * 4 * 512, kernel_initializer=init_kernel) 62 | x = tf.layers.batch_normalization(x, training=is_training, name='batchnorm_1') 63 | x = tf.nn.relu(x) 64 | 65 | x = tf.reshape(x, [-1, 4, 4, 512]) 66 | 67 | with tf.variable_scope('deconv_1'): 68 | x = tf.layers.conv2d_transpose(x, 256, [5, 5], strides=[2, 2], padding='SAME', kernel_initializer=init_kernel) 69 | x = tf.layers.batch_normalization(x, training=is_training, name='batchnorm_2') 70 | x = tf.nn.relu(x) 71 | 72 | with tf.variable_scope('deconv_2'): 73 | x = tf.layers.conv2d_transpose(x, 128, [5, 5], strides=[2, 2], padding='SAME', kernel_initializer=init_kernel) 74 | x = tf.layers.batch_normalization(x, training=is_training, name='batchnormn_3') 75 | x = tf.nn.relu(x) 76 | 77 | with tf.variable_scope('deconv_3'): 78 | output = nn.deconv2d(x, num_filters=3, filter_size=[5, 5], stride=[2, 2], nonlinearity=tf.tanh, init=init, 79 | counters=counter, init_scale=0.1) 80 | return output 81 | -------------------------------------------------------------------------------- /data/cifar10_input.py: -------------------------------------------------------------------------------- 1 | # CIFAR10 Downloader 2 | 3 | import pickle 4 | import os 5 | import errno 6 | import tarfile 7 | import shutil 8 | import numpy as np 9 | 10 | import urllib3 11 | 12 | 13 | _shuffle = True 14 | 15 | 16 | def _unpickle_file(filename): 17 | print("Loading pickle file: {}".format(filename)) 18 | 19 | with open(filename, mode='rb') as file: 20 | data = pickle.load(file, encoding='bytes') 21 | 22 | # Reorder the data 23 | img = data[b'data'] 24 | img = img.reshape([-1, 3, 32, 32]) 25 | img = img.transpose([0, 2, 3, 1]) 26 | # Load labels 27 | lbl = np.array(data[b'labels']) 28 | 29 | return img, lbl 30 | 31 | 32 | def _get_dataset(path,split): 33 | assert split == "test" or split == "train" 34 | dirname = "cifar-10-batches-py" 35 | # data_url = "http://10.217.128.198/datasets/cifar-10-python.tar.gz" 36 | # data_url = "http://10.217.128.198/mnt/data_c/datasets/cifar-10-python.tar.gz" 37 | data_url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 38 | if not os.path.exists(os.path.join(path, dirname)): 39 | # Extract or download data 40 | try: 41 | os.makedirs(path) 42 | except OSError as exception: 43 | if exception.errno != errno.EEXIST: 44 | raise 45 | 46 | file_path = os.path.join(path, data_url.split('/')[-1]) 47 | if not os.path.exists(file_path): 48 | # Download 49 | print("Downloading {}".format(data_url)) 50 | with urllib3.PoolManager().request('GET', data_url, preload_content=False) as r, \ 51 | open(file_path, 'wb') as w: 52 | shutil.copyfileobj(r, w) 53 | 54 | print("Unpacking {}".format(file_path)) 55 | # Unpack data 56 | tarfile.open(name=file_path, mode="r:gz").extractall(path) 57 | 58 | # Import the data 59 | filenames = ["test_batch"] if split == "test" else \ 60 | ["data_batch_{}".format(i) for i in range(1, 6)] 61 | 62 | imgs = [] 63 | lbls = [] 64 | for f in filenames: 65 | img, lbl = _unpickle_file(os.path.join(path, dirname, f)) 66 | imgs.append(img) 67 | lbls.append(lbl) 68 | 69 | # Now we flatten the arrays 70 | imgs = np.concatenate(imgs) 71 | lbls = np.concatenate(lbls) 72 | 73 | # Convert images to [0..1] range 74 | imgs = imgs.astype(np.float32) / 255.0 75 | imgs = imgs *2. -1. 
76 | # Convert images to [-1..1] range 77 | # imgs = (imgs.astype(np.float32)-127.5) / 128. 78 | # Convert label to one hot encoding 79 | # lbl = np.zeros((len(lbls),10)) #lbl s !! 80 | # lbl[np.arange(len(lbls)), lbls] = 1 81 | return imgs, lbls.astype(np.uint8) 82 | 83 | -------------------------------------------------------------------------------- /data/svhn_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from six.moves import urllib 4 | from scipy.io import loadmat 5 | 6 | def maybe_download(data_dir): 7 | new_data_dir = os.path.join(data_dir, 'svhn') 8 | if not os.path.exists(new_data_dir): 9 | os.makedirs(new_data_dir) 10 | def _progress(count, block_size, total_size): 11 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 12 | sys.stdout.flush() 13 | filepath, _ = urllib.request.urlretrieve('http://ufldl.stanford.edu/housenumbers/train_32x32.mat', new_data_dir+'/train_32x32.mat', _progress) 14 | filepath, _ = urllib.request.urlretrieve('http://ufldl.stanford.edu/housenumbers/test_32x32.mat', new_data_dir+'/test_32x32.mat', _progress) 15 | 16 | def load(data_dir, subset='train'): 17 | maybe_download(data_dir) 18 | if subset=='train': 19 | train_data = loadmat(os.path.join(data_dir, 'svhn') + '/train_32x32.mat') 20 | trainx = train_data['X'] 21 | trainy = train_data['y'].flatten() 22 | trainy[trainy==10] = 0 23 | return trainx, trainy 24 | elif subset=='test': 25 | test_data = loadmat(os.path.join(data_dir, 'svhn') + '/test_32x32.mat') 26 | testx = test_data['X'] 27 | testy = test_data['y'].flatten() 28 | testy[testy==10] = 0 29 | return testx, testy 30 | else: 31 | raise NotImplementedError('subset should be either train or test') 32 | -------------------------------------------------------------------------------- /figure/loss_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruno-31/ImprovedGAN-Tensorflow/6f4c21020ec13e69f8bb6a8b7612ac84c05891e0/figure/loss_cifar.png -------------------------------------------------------------------------------- /figure/test_accuracy_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruno-31/ImprovedGAN-Tensorflow/6f4c21020ec13e69f8bb6a8b7612ac84c05891e0/figure/test_accuracy_cifar.png -------------------------------------------------------------------------------- /nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various tensorflow utilities 3 | Function taken from the repo: https://github.com/openai/weightnorm 4 | This repo contains example code for Weight Normalization, as described in their paper: 5 | Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks, by Tim Salimans, and Diederik P. Kingma. 
6 | """ 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from tensorflow.contrib.framework.python.ops import add_arg_scope 11 | 12 | def int_shape(x): 13 | return list(map(int, x.get_shape())) 14 | 15 | def concat_elu(x): 16 | """ like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU """ 17 | axis = len(x.get_shape())-1 18 | return tf.nn.elu(tf.concat(axis, [x, -x])) 19 | 20 | def log_sum_exp(x): 21 | """ numerically stable log_sum_exp implementation that prevents overflow """ 22 | axis = len(x.get_shape())-1 23 | m = tf.reduce_max(x, axis) 24 | m2 = tf.reduce_max(x, axis, keep_dims=True) 25 | return m + tf.log(tf.reduce_sum(tf.exp(x-m2), axis)) 26 | 27 | def log_prob_from_logits(x): 28 | """ numerically stable log_softmax implementation that prevents overflow """ 29 | axis = len(x.get_shape())-1 30 | m = tf.reduce_max(x, axis, keep_dims=True) 31 | return x - m - tf.log(tf.reduce_sum(tf.exp(x-m), axis, keep_dims=True)) 32 | 33 | def discretized_mix_logistic_loss(x,l,sum_all=True): 34 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ 35 | xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) 36 | ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100) 37 | nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics 38 | logit_probs = l[:,:,:,:nr_mix] 39 | l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3]) 40 | means = l[:,:,:,:,:nr_mix] 41 | log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.) 42 | coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix]) 43 | x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels 44 | m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix]) 45 | m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix]) 46 | means = tf.concat(3,[tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3]) 47 | centered_x = x - means 48 | inv_stdv = tf.exp(-log_scales) 49 | plus_in = inv_stdv * (centered_x + 1./255.) 50 | cdf_plus = tf.nn.sigmoid(plus_in) 51 | min_in = inv_stdv * (centered_x - 1./255.) 
52 | cdf_min = tf.nn.sigmoid(min_in) 53 | log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling) 54 | log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling) 55 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 56 | mid_in = inv_stdv * centered_x 57 | log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) 58 | 59 | # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) 60 | 61 | # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() 62 | # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) 63 | 64 | # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) 65 | # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs 66 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue 67 | # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value 68 | log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.select(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) 69 | 70 | log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs) 71 | if sum_all: 72 | return -tf.reduce_sum(log_sum_exp(log_probs)) 73 | else: 74 | return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) 75 | 76 | def sample_from_discretized_mix_logistic(l,nr_mix): 77 | ls = int_shape(l) 78 | xs = ls[:-1] + [3] 79 | # unpack parameters 80 | logit_probs = l[:, :, :, :nr_mix] 81 | l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3]) 82 | # sample mixture indicator from softmax 83 | sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32) 84 | sel = tf.reshape(sel, xs[:-1] + [1,nr_mix]) 85 | # select logistic parameters 86 | means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4) 87 | log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.) 88 | coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4) 89 | # sample from logistic & clip to interval 90 | # we don't actually round to the nearest 8bit value when sampling 91 | u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5) 92 | x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u)) 93 | x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.) 94 | x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.) 95 | x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.) 
96 | return tf.concat(3,[tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])]) 97 | 98 | def get_var_maybe_avg(var_name, ema, **kwargs): 99 | ''' utility for retrieving polyak averaged params ''' 100 | v = tf.get_variable(var_name, **kwargs) 101 | if ema is not None: 102 | v = ema.average(v) 103 | return v 104 | 105 | def get_vars_maybe_avg(var_names, ema, **kwargs): 106 | ''' utility for retrieving polyak averaged params ''' 107 | vars = [] 108 | for vn in var_names: 109 | vars.append(get_var_maybe_avg(vn, ema, **kwargs)) 110 | return vars 111 | 112 | def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999): 113 | ''' Adam optimizer ''' 114 | updates = [] 115 | if type(cost_or_grads) is not list: 116 | grads = tf.gradients(cost_or_grads, params) 117 | else: 118 | grads = cost_or_grads 119 | t = tf.Variable(1., 'adam_t') 120 | for p, g in zip(params, grads): 121 | mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') 122 | if mom1>0: 123 | v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') 124 | v_t = mom1*v + (1. - mom1)*g 125 | v_hat = v_t / (1. - tf.pow(mom1,t)) 126 | updates.append(v.assign(v_t)) 127 | else: 128 | v_hat = g 129 | mg_t = mom2*mg + (1. - mom2)*tf.square(g) 130 | mg_hat = mg_t / (1. - tf.pow(mom2,t)) 131 | g_t = v_hat / tf.sqrt(mg_hat + 1e-8) 132 | p_t = p - lr * g_t 133 | updates.append(mg.assign(mg_t)) 134 | updates.append(p.assign(p_t)) 135 | updates.append(t.assign_add(1)) 136 | return tf.group(*updates) 137 | 138 | def get_name(layer_name, counters): 139 | ''' utlity for keeping track of layer names ''' 140 | if not layer_name in counters: 141 | counters[layer_name] = 0 142 | name = layer_name + '_' + str(counters[layer_name]) 143 | counters[layer_name] += 1 144 | return name 145 | 146 | @add_arg_scope 147 | def dense(x, num_units, nonlinearity=None, init_scale=1., counters={},init=False, ema=None, train_scale=True, init_w=tf.random_normal_initializer(0, 0.05),**kwargs): 148 | ''' fully connected layer ''' 149 | name = get_name('dense', counters) 150 | with tf.variable_scope(name): 151 | if init: 152 | # data based initialization of parameters 153 | V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32, init_w, trainable=True) 154 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0]) 155 | x_init = tf.matmul(x, V_norm) 156 | m_init, v_init = tf.nn.moments(x_init, [0]) 157 | scale_init = init_scale/tf.sqrt(v_init + 1e-10) 158 | # g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=train_scale) 159 | # b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) 160 | g = tf.get_variable('g', dtype=tf.float32, initializer=tf.constant(np.ones(num_units),tf.float32), trainable=train_scale) 161 | b = tf.get_variable('b', dtype=tf.float32, initializer=tf.constant(np.zeros(num_units),tf.float32), trainable=True) 162 | x_init = tf.reshape(scale_init,[1,num_units])*(x_init-tf.reshape(m_init,[1,num_units])) 163 | if nonlinearity is not None: 164 | x_init = nonlinearity(x_init) 165 | return x_init 166 | 167 | else: 168 | V,g,b = get_vars_maybe_avg(['V','g','b'], ema) 169 | # tf.assert_variables_initialized([V,g,b]) 170 | 171 | # use weight normalization (Salimans & Kingma, 2016) 172 | x = tf.matmul(x, V) 173 | scaler = g/tf.sqrt(tf.reduce_sum(tf.square(V),[0])) 174 | x = tf.reshape(scaler,[1,num_units])*x + tf.reshape(b,[1,num_units]) 175 | 176 | # apply nonlinearity 177 | if nonlinearity is not None: 178 | x = nonlinearity(x) 179 | return 
x 180 | 181 | @add_arg_scope 182 | def conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 183 | ''' convolutional layer ''' 184 | name = get_name('conv2d', counters) 185 | with tf.variable_scope(name): 186 | if init: 187 | # data based initialization of parameters 188 | V = tf.get_variable('V', filter_size+[int(x.get_shape()[-1]),num_filters], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) 189 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,2]) 190 | x_init = tf.nn.conv2d(x, V_norm, [1]+stride+[1], pad) 191 | m_init, v_init = tf.nn.moments(x_init, [0,1,2]) 192 | scale_init = init_scale/tf.sqrt(v_init + 1e-8) 193 | # g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) 194 | # b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) 195 | g = tf.get_variable('g', dtype=tf.float32, initializer=tf.constant(np.ones(num_filters),tf.float32), trainable=True) 196 | b = tf.get_variable('b', dtype=tf.float32, initializer=tf.constant(np.zeros(num_filters),tf.float32), trainable=True) 197 | # print(b) 198 | x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters])) 199 | if nonlinearity is not None: 200 | x_init = nonlinearity(x_init) 201 | return x_init 202 | 203 | else: 204 | V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema) 205 | # tf.assert_variables_initialized([V,g,b]) 206 | 207 | # use weight normalization (Salimans & Kingma, 2016) 208 | W = tf.reshape(g,[1,1,1,num_filters])*tf.nn.l2_normalize(V,[0,1,2]) 209 | 210 | # calculate convolutional layer output 211 | x = tf.nn.bias_add(tf.nn.conv2d(x, W, [1]+stride+[1], pad), b) 212 | 213 | # apply nonlinearity 214 | if nonlinearity is not None: 215 | x = nonlinearity(x) 216 | return x 217 | 218 | @add_arg_scope 219 | def deconv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 220 | ''' transposed convolutional layer ''' 221 | name = get_name('deconv2d', counters) 222 | xs = int_shape(x) 223 | if pad=='SAME': 224 | target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters] 225 | else: 226 | target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters] 227 | with tf.variable_scope(name): 228 | if init: 229 | # data based initialization of parameters 230 | V = tf.get_variable('V', filter_size+[num_filters,int(x.get_shape()[-1])], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) 231 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,3]) 232 | x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, [1]+stride+[1], padding=pad) 233 | m_init, v_init = tf.nn.moments(x_init, [0,1,2]) 234 | scale_init = init_scale/tf.sqrt(v_init + 1e-8) 235 | # g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) 236 | # b = tf.get_variable('b', dtype=tf.float32,initializer=-m_init*scale_init, trainable=True) 237 | g = tf.get_variable('g', dtype=tf.float32, initializer=tf.constant(np.ones(num_filters),tf.float32), trainable=True) 238 | b = tf.get_variable('b', dtype=tf.float32,initializer=tf.constant(np.zeros(num_filters),tf.float32), trainable=True) 239 | # print(b) 240 | x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters])) 241 | if nonlinearity is not None: 242 | x_init = nonlinearity(x_init) 243 
| return x_init 244 | 245 | else: 246 | V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema) 247 | # tf.assert_variables_initialized #deprecated on tf 1.3 248 | 249 | # use weight normalization (Salimans & Kingma, 2016)V = t 250 | W = tf.reshape(g,[1,1,num_filters,1])*tf.nn.l2_normalize(V,[0,1,3]) 251 | 252 | # calculate convolutional layer output 253 | x = tf.nn.conv2d_transpose(x, W, target_shape, [1]+stride+[1], padding=pad) 254 | x = tf.nn.bias_add(x, b) 255 | 256 | # apply nonlinearity 257 | if nonlinearity is not None: 258 | x = nonlinearity(x) 259 | return x 260 | 261 | @add_arg_scope 262 | def nin(x, num_units, **kwargs): 263 | """ a network in network layer (1x1 CONV) """ 264 | s = int_shape(x) 265 | x = tf.reshape(x, [np.prod(s[:-1]),s[-1]]) 266 | x = dense(x, num_units, **kwargs) 267 | return tf.reshape(x, s[:-1]+[num_units]) 268 | 269 | ''' meta-layer consisting of multiple base layers ''' 270 | 271 | @add_arg_scope 272 | def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs): 273 | xs = int_shape(x) 274 | num_filters = xs[-1] 275 | 276 | c1 = conv(nonlinearity(x), num_filters) 277 | if a is not None: # add short-cut connection if auxiliary input 'a' is given 278 | c1 += nin(nonlinearity(a), num_filters) 279 | c1 = nonlinearity(c1) 280 | if dropout_p > 0: 281 | c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p) 282 | c2 = conv(c1, num_filters * 2, init_scale=0.1) 283 | 284 | # add projection of h vector if included: conditional generation 285 | if h is not None: 286 | with tf.variable_scope(get_name('conditional_weights', counters)): 287 | hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32, 288 | initializer=tf.random_normal_initializer(0, 0.05), trainable=True) 289 | if init: 290 | hw = hw.initialized_value() 291 | c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters]) 292 | 293 | a, b = tf.split(3, 2, c2) 294 | c3 = a * tf.nn.sigmoid(b) 295 | return x + c3 296 | 297 | ''' utilities for shifting the image around, efficient alternative to masking convolutions ''' 298 | 299 | def down_shift(x): 300 | xs = int_shape(x) 301 | return tf.concat(1,[tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]]) 302 | 303 | def right_shift(x): 304 | xs = int_shape(x) 305 | return tf.concat(2,[tf.zeros([xs[0],xs[1],1,xs[3]]), x[:,:,:xs[2]-1,:]]) 306 | 307 | @add_arg_scope 308 | def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): 309 | x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]]) 310 | return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 311 | 312 | @add_arg_scope 313 | def down_shifted_deconv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): 314 | x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 315 | xs = int_shape(x) 316 | return x[:,:(xs[1]-filter_size[0]+1),int((filter_size[1]-1)/2):(xs[2]-int((filter_size[1]-1)/2)),:] 317 | 318 | @add_arg_scope 319 | def down_right_shifted_conv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): 320 | x = tf.pad(x, [[0,0],[filter_size[0]-1, 0], [filter_size[1]-1, 0],[0,0]]) 321 | return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 322 | 323 | @add_arg_scope 324 | def down_right_shifted_deconv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): 325 | x = deconv2d(x, num_filters, 
filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 326 | xs = int_shape(x) 327 | return x[:,:(xs[1]-filter_size[0]+1):,:(xs[2]-filter_size[1]+1),:] 328 | -------------------------------------------------------------------------------- /svhn_gan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import nn # OpenAI implemetation of weightnormalization (Salimans & Kingma) 3 | 4 | init_kernel = tf.random_normal_initializer(mean=0, stddev=0.05) 5 | 6 | 7 | def leakyReLu(x, alpha=0.2, name=None): 8 | if name: 9 | with tf.variable_scope(name): 10 | return _leakyReLu_impl(x, alpha) 11 | else: 12 | return _leakyReLu_impl(x, alpha) 13 | 14 | 15 | def _leakyReLu_impl(x, alpha): 16 | return tf.nn.relu(x) - (alpha * tf.nn.relu(-x)) 17 | 18 | def gaussian_noise_layer(input_layer, std): 19 | noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 20 | return input_layer + noise 21 | 22 | 23 | def discriminator(inp, is_training, init=False, reuse=False, getter =None): 24 | with tf.variable_scope('discriminator_model', reuse=reuse,custom_getter=getter): 25 | counter = {} 26 | x = tf.reshape(inp, [-1, 32, 32, 3]) 27 | 28 | x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout_0') 29 | 30 | x = nn.conv2d(x, 64, nonlinearity=leakyReLu, init=init, counters=counter) 31 | x = nn.conv2d(x, 64, nonlinearity=leakyReLu, init=init, counters=counter) 32 | x = nn.conv2d(x, 64, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter) 33 | 34 | x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_1') 35 | 36 | x = nn.conv2d(x, 128, nonlinearity=leakyReLu, init=init, counters=counter) 37 | x = nn.conv2d(x, 128, nonlinearity=leakyReLu, init=init, counters=counter) 38 | x = nn.conv2d(x, 128, stride=[2, 2], nonlinearity=leakyReLu, init=init, counters=counter) 39 | 40 | x = tf.layers.dropout(x, rate=0.5, training=is_training, name='dropout_2') 41 | 42 | x = nn.conv2d(x, 128, pad='VALID', nonlinearity=leakyReLu, init=init, counters=counter) 43 | x = nn.nin(x, 128, counters=counter, nonlinearity=leakyReLu, init=init) 44 | x = nn.nin(x, 128, counters=counter, nonlinearity=leakyReLu, init=init) 45 | x = tf.layers.max_pooling2d(x, pool_size=6, strides=1, 46 | name='avg_pool_0') 47 | x = tf.squeeze(x, [1, 2]) 48 | 49 | intermediate_layer = x 50 | 51 | logits = nn.dense(x, 10, nonlinearity=None, init=init, counters=counter, init_scale=0.1) 52 | 53 | return logits, intermediate_layer 54 | 55 | 56 | def generator(z_seed, is_training, init=False,reuse=False): 57 | with tf.variable_scope('generator_model', reuse=reuse): 58 | counter = {} 59 | x = z_seed 60 | with tf.variable_scope('dense_1'): 61 | x = tf.layers.dense(x, units=4 * 4 * 512, kernel_initializer=init_kernel) 62 | x = tf.layers.batch_normalization(x, training=is_training, name='batchnorm_1') 63 | x = tf.nn.relu(x) 64 | 65 | x = tf.reshape(x, [-1, 4, 4, 512]) 66 | 67 | with tf.variable_scope('deconv_1'): 68 | x = tf.layers.conv2d_transpose(x, 256, [5, 5], strides=[2, 2], padding='SAME', kernel_initializer=init_kernel) 69 | x = tf.layers.batch_normalization(x, training=is_training, name='batchnorm_2') 70 | x = tf.nn.relu(x) 71 | 72 | with tf.variable_scope('deconv_2'): 73 | x = tf.layers.conv2d_transpose(x, 128, [5, 5], strides=[2, 2], padding='SAME', kernel_initializer=init_kernel) 74 | x = tf.layers.batch_normalization(x, training=is_training, name='batchnormn_3') 75 | x = tf.nn.relu(x) 76 | 77 | with 
tf.variable_scope('deconv_3'): 78 | output = nn.deconv2d(x, num_filters=3, filter_size=[5, 5], stride=[2, 2], nonlinearity=tf.tanh, init=init, 79 | counters=counter, init_scale=0.1) 80 | return output 81 | -------------------------------------------------------------------------------- /train_cifar.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from data import cifar10_input 7 | from cifar_gan import discriminator, generator 8 | import sys 9 | import os 10 | 11 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | 13 | flags = tf.app.flags 14 | flags.DEFINE_integer('gpu', 0, 'gpu [0]') 15 | flags.DEFINE_integer('batch_size', 100, "batch size [100]") 16 | flags.DEFINE_string('data_dir', '/tmp/data/cifar-10-python/','data directory') 17 | flags.DEFINE_string('logdir', './log/cifar', 'log directory') 18 | flags.DEFINE_integer('seed', 10, 'numpy seed [10]') 19 | flags.DEFINE_integer('labeled', 400, 'labeled data per class [400]') 20 | flags.DEFINE_float('learning_rate', 0.0003, 'learning rate [0.0003]') 21 | flags.DEFINE_float('unl_weight', 1.0, 'unlabeled weight [1.]') 22 | flags.DEFINE_float('lbl_weight', 1.0, 'labeled weight [1.]') 23 | flags.DEFINE_float('ma_decay', 0.9999, 'exponential moving average for inference [0.9999]') 24 | flags.DEFINE_integer('decay_start', 1000, 'start learning rate decay [1000]') 25 | flags.DEFINE_integer('epoch', 1200, 'epochs [1200]') 26 | flags.DEFINE_boolean('validation', False, 'validation [False]') 27 | 28 | flags.DEFINE_integer('freq_print', 10000, 'frequency image print tensorboard [10000]') 29 | flags.DEFINE_integer('step_print', 50, 'frequency scalar print tensorboard [50]') 30 | flags.DEFINE_integer('freq_test', 1, 'frequency test in epochs [1]') 31 | flags.DEFINE_integer('freq_save', 10, 'frequency saver in epochs [10]') 32 | FLAGS = flags.FLAGS 33 | 34 | 35 | def get_getter(ema): 36 | def ema_getter(getter, name, *args, **kwargs): 37 | var = getter(name, *args, **kwargs) 38 | ema_var = ema.average(var) 39 | return ema_var if ema_var else var 40 | return ema_getter 41 | 42 | 43 | def display_progression_epoch(j, id_max): 44 | batch_progression = int((j / id_max) * 100) 45 | sys.stdout.write(str(batch_progression) + ' % epoch' + chr(13)) 46 | sys.stdout.flush() 47 | 48 | 49 | def linear_decay(decay_start, decay_end, epoch): 50 | return min(-1 / (decay_end - decay_start) * epoch + 1 + decay_start / (decay_end - decay_start),1) 51 | 52 | 53 | def main(_): 54 | print("\nParameters:") 55 | for attr,value in tf.app.flags.FLAGS.flag_values_dict().items(): 56 | print("{}={}".format(attr,value)) 57 | print("") 58 | 59 | os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu) 60 | 61 | if not os.path.exists(FLAGS.logdir): 62 | os.makedirs(FLAGS.logdir) 63 | 64 | # Random seed 65 | rng = np.random.RandomState(FLAGS.seed) # seed labels 66 | rng_data = np.random.RandomState(rng.randint(0, 2**10)) # seed shuffling 67 | 68 | # load CIFAR-10 69 | trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images 70 | testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') 71 | trainx_unl = trainx.copy() 72 | trainx_unl2 = trainx.copy() 73 | 74 | if FLAGS.validation: 75 | split = int(0.1 * trainx.shape[0]) 76 | print("validation enabled") 77 | testx = trainx[:split] 78 | testy = trainy[:split] 79 | trainx = trainx[split:] 80 | trainy = trainy[split:] 81 | 82 | nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) 83 | nr_batches_test =
int(testx.shape[0] / FLAGS.batch_size) 84 | 85 | # select labeled data 86 | inds = rng_data.permutation(trainx.shape[0]) 87 | trainx = trainx[inds] 88 | trainy = trainy[inds] 89 | txs = [] 90 | tys = [] 91 | for j in range(10): 92 | txs.append(trainx[trainy == j][:FLAGS.labeled]) 93 | tys.append(trainy[trainy == j][:FLAGS.labeled]) 94 | txs = np.concatenate(txs, axis=0) 95 | tys = np.concatenate(tys, axis=0) 96 | 97 | '''construct graph''' 98 | unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='unlabeled_data_input_pl') 99 | is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') 100 | inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='labeled_data_input_pl') 101 | lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl') 102 | # scalar pl 103 | lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl') 104 | acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl') 105 | acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl') 106 | acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl') 107 | 108 | random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z') 109 | generator(random_z, is_training_pl, init=True) # init of weightnorm weights 110 | gen_inp = generator(random_z, is_training_pl, init=False, reuse=True) 111 | discriminator(unl, is_training_pl, init=True) 112 | logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True) 113 | logits_gen, layer_fake = discriminator(gen_inp, is_training_pl, init=False, reuse=True) 114 | logits_unl, layer_real = discriminator(unl, is_training_pl, init=False, reuse=True) 115 | 116 | with tf.name_scope('loss_functions'): 117 | # discriminator 118 | l_unl = tf.reduce_logsumexp(logits_unl, axis=1) 119 | l_gen = tf.reduce_logsumexp(logits_gen, axis=1) 120 | loss_lab = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl, logits=logits_lab)) 121 | loss_unl = - 0.5 * tf.reduce_mean(l_unl) \ 122 | + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \ 123 | + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen)) 124 | 125 | # generator 126 | m1 = tf.reduce_mean(layer_real, axis=0) 127 | m2 = tf.reduce_mean(layer_fake, axis=0) 128 | 129 | loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab 130 | loss_gen = tf.reduce_mean(tf.abs(m1 - m2)) 131 | correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32), tf.cast(lbl, tf.int32)) 132 | accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 133 | 134 | 135 | with tf.name_scope('optimizers'): 136 | # control op dependencies for batch norm and trainable variables 137 | tvars = tf.trainable_variables() 138 | dvars = [var for var in tvars if 'discriminator_model' in var.name] 139 | gvars = [var for var in tvars if 'generator_model' in var.name] 140 | 141 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 142 | update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)] 143 | update_ops_dis = [x for x in update_ops if ('discriminator_model' in x.name)] 144 | optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='dis_optimizer') 145 | optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='gen_optimizer') 146 | 147 | with tf.control_dependencies(update_ops_gen): 148 | train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars) 149 | 150 | dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars) 151 | ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) 152 | maintain_averages_op = 
ema.apply(dvars) 153 | 154 | with tf.control_dependencies([dis_op]): 155 | train_dis_op = tf.group(maintain_averages_op) 156 | 157 | logits_ema, _ = discriminator(inp, is_training_pl, getter=get_getter(ema), reuse=True) 158 | correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) 159 | accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) 160 | 161 | with tf.name_scope('summary'): 162 | with tf.name_scope('discriminator'): 163 | tf.summary.scalar('loss_discriminator', loss_dis, ['dis']) 164 | 165 | with tf.name_scope('generator'): 166 | tf.summary.scalar('loss_generator', loss_gen, ['gen']) 167 | 168 | with tf.name_scope('images'): 169 | tf.summary.image('gen_images', gen_inp, 10, ['image']) 170 | 171 | with tf.name_scope('epoch'): 172 | tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch']) 173 | tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch']) 174 | tf.summary.scalar('accuracy_test', acc_test_pl, ['epoch']) 175 | tf.summary.scalar('learning_rate', lr_pl, ['epoch']) 176 | 177 | sum_op_dis = tf.summary.merge_all('dis') 178 | sum_op_gen = tf.summary.merge_all('gen') 179 | sum_op_im = tf.summary.merge_all('image') 180 | sum_op_epoch = tf.summary.merge_all('epoch') 181 | 182 | # training global varialble 183 | global_epoch = tf.Variable(0, trainable=False, name='global_epoch') 184 | global_step = tf.Variable(0, trainable=False, name='global_step') 185 | inc_global_step = tf.assign(global_step, global_step+1) 186 | inc_global_epoch = tf.assign(global_epoch, global_epoch+1) 187 | 188 | # op initializer for session manager 189 | init_gen = [var.initializer for var in gvars][:-3] 190 | with tf.control_dependencies(init_gen): 191 | op = tf.global_variables_initializer() 192 | init_feed_dict = {inp: trainx_unl[:FLAGS.batch_size], unl: trainx_unl[:FLAGS.batch_size], is_training_pl: True} 193 | 194 | sv = tf.train.Supervisor(logdir=FLAGS.logdir, global_step=global_epoch, summary_op=None, save_model_secs=0, 195 | init_op=op,init_feed_dict=init_feed_dict) 196 | 197 | '''//////training //////''' 198 | print('start training') 199 | with sv.managed_session() as sess: 200 | tf.set_random_seed(rng.randint(2 ** 10)) 201 | print('\ninitialization done') 202 | print('Starting training from epoch :%d, step:%d \n'%(sess.run(global_epoch),sess.run(global_step))) 203 | 204 | writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) 205 | 206 | while not sv.should_stop(): 207 | epoch = sess.run(global_epoch) 208 | train_batch = sess.run(global_step) 209 | 210 | if (epoch >= FLAGS.epoch): 211 | print("Training done") 212 | sv.stop() 213 | break 214 | 215 | begin = time.time() 216 | train_loss_lab=train_loss_unl=train_loss_gen=train_acc=test_acc=test_acc_ma= 0 217 | lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start,FLAGS.epoch,epoch) 218 | 219 | # construct randomly permuted batches 220 | trainx = [] 221 | trainy = [] 222 | for t in range(int(np.ceil(trainx_unl.shape[0] / float(txs.shape[0])))): # same size lbl and unlb 223 | inds = rng.permutation(txs.shape[0]) 224 | trainx.append(txs[inds]) 225 | trainy.append(tys[inds]) 226 | trainx = np.concatenate(trainx, axis=0) 227 | trainy = np.concatenate(trainy, axis=0) 228 | trainx_unl = trainx_unl[rng.permutation(trainx_unl.shape[0])] # shuffling unl dataset 229 | trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])] 230 | 231 | # training 232 | for t in range(nr_batches_train): 233 | 234 | display_progression_epoch(t, nr_batches_train) 235 | ran_from = t * 
FLAGS.batch_size 236 | ran_to = (t + 1) * FLAGS.batch_size 237 | 238 | # train discriminator 239 | feed_dict = {unl: trainx_unl[ran_from:ran_to], 240 | is_training_pl: True, 241 | inp: trainx[ran_from:ran_to], 242 | lbl: trainy[ran_from:ran_to], 243 | lr_pl: lr} 244 | _, acc, lu, lb, sm = sess.run([train_dis_op, accuracy_classifier, loss_lab, loss_unl, sum_op_dis], 245 | feed_dict=feed_dict) 246 | train_loss_unl += lu 247 | train_loss_lab += lb 248 | train_acc += acc 249 | if (train_batch % FLAGS.step_print) == 0: 250 | writer.add_summary(sm, train_batch) 251 | 252 | # train generator 253 | _, lg, sm = sess.run([train_gen_op, loss_gen, sum_op_gen], feed_dict={unl: trainx_unl2[ran_from:ran_to], 254 | is_training_pl: True, 255 | lr_pl: lr}) 256 | train_loss_gen += lg 257 | if (train_batch % FLAGS.step_print) == 0: 258 | writer.add_summary(sm, train_batch) 259 | 260 | if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0): 261 | ran_from = np.random.randint(0, trainx_unl.shape[0] - FLAGS.batch_size) 262 | ran_to = ran_from + FLAGS.batch_size 263 | sm = sess.run(sum_op_im, 264 | feed_dict={is_training_pl: True, unl: trainx_unl[ran_from:ran_to]}) 265 | writer.add_summary(sm, train_batch) 266 | 267 | train_batch += 1 268 | sess.run(inc_global_step) 269 | 270 | train_loss_lab /= nr_batches_train 271 | train_loss_unl /= nr_batches_train 272 | train_loss_gen /= nr_batches_train 273 | train_acc /= nr_batches_train 274 | 275 | # Testing moving averaged model and raw model 276 | if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch-1): 277 | for t in range(nr_batches_test): 278 | ran_from = t * FLAGS.batch_size 279 | ran_to = (t + 1) * FLAGS.batch_size 280 | feed_dict = {inp: testx[ran_from:ran_to], 281 | lbl: testy[ran_from:ran_to], 282 | is_training_pl: False} 283 | acc, acc_ema = sess.run([accuracy_classifier, accuracy_ema], feed_dict=feed_dict) 284 | test_acc += acc 285 | test_acc_ma += acc_ema 286 | test_acc /= nr_batches_test 287 | test_acc_ma /= nr_batches_test 288 | 289 | sum = sess.run(sum_op_epoch, feed_dict={acc_train_pl: train_acc, 290 | acc_test_pl: test_acc, 291 | acc_test_pl_ema: test_acc_ma, 292 | lr_pl: lr}) 293 | writer.add_summary(sum, epoch) 294 | 295 | print( 296 | "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f " 297 | "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f" 298 | % (epoch, time.time() - begin, train_loss_gen, train_loss_lab, train_loss_unl, train_acc, 299 | test_acc, test_acc_ma)) 300 | 301 | sess.run(inc_global_epoch) 302 | 303 | # save snapshots of model 304 | if ((epoch % FLAGS.freq_save == 0) & (epoch!=0) ) | (epoch == FLAGS.epoch-1): 305 | string = 'model-' + str(epoch) 306 | save_path = os.path.join(FLAGS.logdir, string) 307 | sv.saver.save(sess, save_path) 308 | print("Model saved in file: %s" % (save_path)) 309 | 310 | 311 | if __name__ == '__main__': 312 | tf.app.run() 313 | -------------------------------------------------------------------------------- /train_svhn.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from data import svhn_data 7 | from cifar_gan import discriminator, generator 8 | import sys 9 | import os 10 | 11 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | 13 | flags = tf.app.flags 14 | flags.DEFINE_integer('gpu', 0, 'labeled data per class') 15 | flags.DEFINE_integer('batch_size', 100, "batch size [50]") 16 | flags.DEFINE_string('data_dir', '/tmp/data/svhn/','') 17 | 
flags.DEFINE_string('logdir', './log/svhn', 'log directory') 18 | flags.DEFINE_integer('seed', 10, 'numpy seed [10]') 19 | flags.DEFINE_integer('labeled', 100, 'labeled data per class [100]') 20 | flags.DEFINE_float('learning_rate', 0.0003, 'learning rate [0.0003]') 21 | flags.DEFINE_float('unl_weight', 1.0, 'unlabeled weight [1.]') 22 | flags.DEFINE_float('lbl_weight', 1.0, 'labeled weight [1.]') 23 | flags.DEFINE_float('ma_decay', 0.9999, 'exponential moving average for inference [0.9999]') 24 | flags.DEFINE_boolean('validation', False, 'validation [False]') 25 | flags.DEFINE_integer('decay_start', 399, 'start of learning rate decay [399]') 26 | flags.DEFINE_integer('epoch', 400, 'epochs [400]') 27 | 28 | 29 | flags.DEFINE_integer('freq_print', 10000, 'frequency image print tensorboard [10000]') 30 | flags.DEFINE_integer('step_print', 50, 'frequency scalar print tensorboard [50]') 31 | flags.DEFINE_integer('freq_test', 1, 'frequency test in epochs [1]') 32 | flags.DEFINE_integer('freq_save', 10, 'frequency saver in epochs [10]') 33 | FLAGS = flags.FLAGS 34 | 35 | def get_getter(ema): 36 | def ema_getter(getter, name, *args, **kwargs): 37 | var = getter(name, *args, **kwargs) 38 | ema_var = ema.average(var) 39 | return ema_var if ema_var else var 40 | return ema_getter 41 | 42 | 43 | def display_progression_epoch(j, id_max): 44 | batch_progression = int((j / id_max) * 100) 45 | sys.stdout.write(str(batch_progression) + ' % epoch' + chr(13)) 46 | sys.stdout.flush() 47 | 48 | 49 | def linear_decay(decay_start, decay_end, epoch): 50 | return min(-1 / (decay_end - decay_start) * epoch + 1 + decay_start / (decay_end - decay_start),1) 51 | 52 | 53 | def main(_): 54 | print("\nParameters:") 55 | for attr,value in tf.app.flags.FLAGS.flag_values_dict().items(): 56 | print("{}={}".format(attr,value)) 57 | print("") 58 | os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu) 59 | 60 | if not os.path.exists(FLAGS.logdir): 61 | os.makedirs(FLAGS.logdir) 62 | 63 | # Random seed 64 | rng = np.random.RandomState(FLAGS.seed) # seed labels 65 | rng_data = np.random.RandomState(rng.randint(0, 2**10)) # seed shuffling 66 | 67 | trainx, trainy = svhn_data.load(FLAGS.data_dir, 'train') 68 | testx, testy = svhn_data.load(FLAGS.data_dir, 'test') 69 | def rescale(mat): 70 | return np.transpose(((-127.5 + mat) / 127.5), (3, 0, 1, 2)) 71 | trainx = rescale(trainx) 72 | testx = rescale(testx) 73 | 74 | if FLAGS.validation: 75 | split = int(0.1 * trainx.shape[0]) 76 | print("validation enabled") 77 | testx = trainx[:split] 78 | testy = trainy[:split] 79 | trainx = trainx[split:] 80 | trainy = trainy[split:] 81 | 82 | trainx_unl = trainx.copy() 83 | trainx_unl2 = trainx.copy() 84 | 85 | nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) 86 | nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) 87 | 88 | # select labeled data 89 | inds = rng_data.permutation(trainx.shape[0]) 90 | trainx = trainx[inds] 91 | trainy = trainy[inds] 92 | txs = [] 93 | tys = [] 94 | for j in range(10): 95 | txs.append(trainx[trainy == j][:FLAGS.labeled]) 96 | tys.append(trainy[trainy == j][:FLAGS.labeled]) 97 | txs = np.concatenate(txs, axis=0) 98 | tys = np.concatenate(tys, axis=0) 99 | 100 | '''construct graph''' 101 | print('constructing graph') 102 | unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='unlabeled_data_input_pl') 103 | is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') 104 | inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='labeled_data_input_pl') 105 | lbl = tf.placeholder(tf.int32,
[FLAGS.batch_size], name='lbl_input_pl') 106 | # scalar pl 107 | lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl') 108 | acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl') 109 | acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl') 110 | acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl') 111 | 112 | random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z') 113 | generator(random_z, is_training_pl, init=True) # init of weightnorm weights 114 | gen_inp = generator(random_z, is_training_pl, init=False, reuse=True) 115 | 116 | discriminator(unl, is_training_pl, init=True) 117 | logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True) 118 | logits_gen, layer_fake = discriminator(gen_inp, is_training_pl, init=False, reuse=True) 119 | logits_unl, layer_real = discriminator(unl, is_training_pl, init=False, reuse=True) 120 | 121 | with tf.name_scope('loss_functions'): 122 | l_unl = tf.reduce_logsumexp(logits_unl, axis=1) 123 | l_gen = tf.reduce_logsumexp(logits_gen, axis=1) 124 | # discriminator 125 | loss_lab = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl, logits=logits_lab)) 126 | loss_unl = - 0.5 * tf.reduce_mean(l_unl) \ 127 | + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \ 128 | + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen)) 129 | # generator 130 | m1 = tf.reduce_mean(layer_real, axis=0) 131 | m2 = tf.reduce_mean(layer_fake, axis=0) 132 | 133 | 134 | loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab 135 | loss_gen = tf.reduce_mean(tf.abs(m1 - m2)) 136 | correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32), tf.cast(lbl, tf.int32)) 137 | accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 138 | 139 | with tf.name_scope('optimizers'): 140 | # control op dependencies for batch norm and trainable variables 141 | tvars = tf.trainable_variables() 142 | 143 | dvars = [var for var in tvars if 'discriminator_model' in var.name] 144 | gvars = [var for var in tvars if 'generator_model' in var.name] 145 | 146 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 147 | update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)] 148 | update_ops_dis = [x for x in update_ops if ('discriminator_model' in x.name)] 149 | 150 | optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='dis_optimizer') 151 | optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='gen_optimizer') 152 | 153 | with tf.control_dependencies(update_ops_gen): 154 | train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars) 155 | 156 | dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars) 157 | 158 | ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) 159 | maintain_averages_op = ema.apply(dvars) 160 | 161 | with tf.control_dependencies([dis_op]): 162 | train_dis_op = tf.group(maintain_averages_op) 163 | 164 | logits_ema, _ = discriminator(inp, is_training_pl, getter=get_getter(ema), reuse=True) 165 | correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) 166 | accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) 167 | 168 | with tf.name_scope('summary'): 169 | with tf.name_scope('discriminator'): 170 | tf.summary.scalar('loss_discriminator', loss_dis, ['dis']) 171 | 172 | with tf.name_scope('generator'): 173 | tf.summary.scalar('loss_generator', loss_gen, ['gen']) 174 | 175 | with tf.name_scope('images'): 176 | tf.summary.image('gen_images', gen_inp, 
10, ['image']) 177 | 178 | with tf.name_scope('epoch'): 179 | tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch']) 180 | tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch']) 181 | tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch']) 182 | tf.summary.scalar('learning_rate', lr_pl, ['epoch']) 183 | 184 | sum_op_dis = tf.summary.merge_all('dis') 185 | sum_op_gen = tf.summary.merge_all('gen') 186 | sum_op_im = tf.summary.merge_all('image') 187 | sum_op_epoch = tf.summary.merge_all('epoch') 188 | 189 | # training global varialble 190 | global_epoch = tf.Variable(0, trainable=False, name='global_epoch') 191 | global_step = tf.Variable(0, trainable=False, name='global_step') 192 | inc_global_step = tf.assign(global_step, global_step+1) 193 | inc_global_epoch = tf.assign(global_epoch, global_epoch+1) 194 | 195 | # op initializer for session manager 196 | init_gen = [var.initializer for var in gvars][:-3] 197 | with tf.control_dependencies(init_gen): 198 | op = tf.global_variables_initializer() 199 | init_feed_dict = {inp: trainx_unl[:FLAGS.batch_size], unl: trainx_unl[:FLAGS.batch_size], 200 | is_training_pl: True} 201 | 202 | sv = tf.train.Supervisor(logdir=FLAGS.logdir, global_step=global_epoch, summary_op=None, save_model_secs=0, 203 | init_op=op,init_feed_dict=init_feed_dict) 204 | 205 | 206 | '''//////training //////''' 207 | print('start training') 208 | with sv.managed_session() as sess: 209 | tf.set_random_seed(rng.randint(2 ** 10)) 210 | print('\ninitialization done') 211 | print('Starting training from epoch :%d, step:%d \n'%(sess.run(global_epoch),sess.run(global_step))) 212 | 213 | writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) 214 | 215 | while not sv.should_stop(): 216 | epoch = sess.run(global_epoch) 217 | train_batch = sess.run(global_step) 218 | 219 | if (epoch >= FLAGS.epoch): 220 | print("Training done") 221 | sv.stop() 222 | break 223 | 224 | begin = time.time() 225 | train_loss_lab=train_loss_unl=train_loss_gen=train_acc=test_acc=test_acc_ma= 0 226 | lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start,FLAGS.epoch,epoch) 227 | 228 | # construct randomly permuted batches 229 | trainx = [] 230 | trainy = [] 231 | for t in range(int(np.ceil(trainx_unl.shape[0] / float(txs.shape[0])))): # same size lbl and unlb 232 | inds = rng.permutation(txs.shape[0]) 233 | trainx.append(txs[inds]) 234 | trainy.append(tys[inds]) 235 | trainx = np.concatenate(trainx, axis=0) 236 | trainy = np.concatenate(trainy, axis=0) 237 | trainx_unl = trainx_unl[rng.permutation(trainx_unl.shape[0])] # shuffling unl dataset 238 | trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])] 239 | 240 | # training 241 | for t in range(nr_batches_train): 242 | 243 | display_progression_epoch(t, nr_batches_train) 244 | ran_from = t * FLAGS.batch_size 245 | ran_to = (t + 1) * FLAGS.batch_size 246 | 247 | # train discriminator 248 | feed_dict = {unl: trainx_unl[ran_from:ran_to], 249 | is_training_pl: True, 250 | inp: trainx[ran_from:ran_to], 251 | lbl: trainy[ran_from:ran_to], 252 | lr_pl: lr} 253 | _, acc, lu, lb, sm = sess.run([train_dis_op, accuracy_classifier, loss_lab, loss_unl, sum_op_dis], 254 | feed_dict=feed_dict) 255 | train_loss_unl += lu 256 | train_loss_lab += lb 257 | train_acc += acc 258 | if (train_batch % FLAGS.step_print) == 0: 259 | writer.add_summary(sm, train_batch) 260 | 261 | # train generator 262 | _, lg, sm = sess.run([train_gen_op, loss_gen, sum_op_gen], feed_dict={unl: trainx_unl2[ran_from:ran_to], 263 | is_training_pl: True, 
264 | lr_pl: lr}) 265 | train_loss_gen += lg 266 | if (train_batch % FLAGS.step_print) == 0: 267 | writer.add_summary(sm, train_batch) 268 | 269 | if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0): 270 | ran_from = np.random.randint(0, trainx_unl.shape[0] - FLAGS.batch_size) 271 | ran_to = ran_from + FLAGS.batch_size 272 | sm = sess.run(sum_op_im, 273 | feed_dict={is_training_pl: True, unl: trainx_unl[ran_from:ran_to]}) 274 | writer.add_summary(sm, train_batch) 275 | 276 | train_batch += 1 277 | sess.run(inc_global_step) 278 | 279 | train_loss_lab /= nr_batches_train 280 | train_loss_unl /= nr_batches_train 281 | train_loss_gen /= nr_batches_train 282 | train_acc /= nr_batches_train 283 | 284 | # Testing moving averaged model and raw model 285 | if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch-1): 286 | for t in range(nr_batches_test): 287 | ran_from = t * FLAGS.batch_size 288 | ran_to = (t + 1) * FLAGS.batch_size 289 | feed_dict = {inp: testx[ran_from:ran_to], 290 | lbl: testy[ran_from:ran_to], 291 | is_training_pl: False} 292 | acc, acc_ema = sess.run([accuracy_classifier, accuracy_ema], feed_dict=feed_dict) 293 | test_acc += acc 294 | test_acc_ma += acc_ema 295 | test_acc /= nr_batches_test 296 | test_acc_ma /= nr_batches_test 297 | 298 | sum = sess.run(sum_op_epoch, feed_dict={acc_train_pl: train_acc, 299 | acc_test_pl: test_acc, 300 | acc_test_pl_ema: test_acc_ma, 301 | lr_pl: lr}) 302 | writer.add_summary(sum, epoch) 303 | 304 | print( 305 | "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f " 306 | "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f" 307 | % (epoch, time.time() - begin, train_loss_gen, train_loss_lab, train_loss_unl, train_acc, 308 | test_acc, test_acc_ma)) 309 | 310 | sess.run(inc_global_epoch) 311 | 312 | # save snap shot of model 313 | if ((epoch % FLAGS.freq_save == 0) & (epoch!=0) ) | (epoch == FLAGS.epoch-1): 314 | string = 'model-' + str(epoch) 315 | save_path = os.path.join(FLAGS.logdir, string) 316 | sv.saver.save(sess, save_path) 317 | print("Model saved in file: %s" % (save_path)) 318 | 319 | 320 | if __name__ == '__main__': 321 | tf.app.run() 322 | --------------------------------------------------------------------------------
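The unlabeled discriminator loss used in both training scripts, `-0.5*E[l_unl] + 0.5*E[softplus(l_unl)] + 0.5*E[softplus(l_gen)]` with `l = logsumexp(logits)`, is the standard GAN discriminator loss written in terms of the K class logits, using `D(x) = Z(x)/(Z(x)+1)` with `Z(x) = sum_k exp(logits_k)` as in Salimans et al. A short NumPy check of that identity (illustrative only, not part of the repository; the helper functions below are assumptions of this sketch):

```
# Sanity check: the repo's unlabeled loss equals -0.5*log D(x_unl) - 0.5*log(1 - D(G(z)))
# with D(x) = Z(x)/(Z(x)+1), Z(x) = sum_k exp(logits_k).
import numpy as np

def softplus(v):
    return np.log1p(np.exp(v))

def logsumexp(v):
    m = v.max()
    return m + np.log(np.exp(v - m).sum())

rng = np.random.RandomState(0)
logits_unl = rng.randn(10)   # 10-class logits for a real, unlabeled image
logits_gen = rng.randn(10)   # 10-class logits for a generated image

l_unl, l_gen = logsumexp(logits_unl), logsumexp(logits_gen)
loss_repo = -0.5 * l_unl + 0.5 * softplus(l_unl) + 0.5 * softplus(l_gen)

d_unl = np.exp(l_unl) / (np.exp(l_unl) + 1.)
d_gen = np.exp(l_gen) / (np.exp(l_gen) + 1.)
loss_gan = -0.5 * np.log(d_unl) - 0.5 * np.log(1. - d_gen)

print(loss_repo, loss_gan)   # the two values agree up to floating point error
```

Expressing the loss through softplus of the log-sum-exp avoids forming D(x) explicitly and keeps the objective numerically stable.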