├── README.md
├── glow
│   ├── .gitignore
│   ├── Glow.py
│   ├── config_glow.py
│   ├── flow_layers.py
│   ├── main.py
│   ├── nets.py
│   ├── tf_ops.py
│   └── util
│       ├── HandleCIFAR10.py
│       ├── HandleSVHN.py
│       ├── cifar10.py
│       ├── dataset_utils.py
│       ├── svhn.py
│       └── utils.py
├── lvat
│   ├── .gitignore
│   ├── cifar10.py
│   ├── cnn.py
│   ├── dataset_utils.py
│   ├── layers_vat.py
│   ├── svhn.py
│   ├── train_semisup.py
│   ├── utils.py
│   └── vat.py
└── vae
    ├── .gitignore
    ├── VAE.py
    ├── build_AE.py
    ├── config.py
    └── util
        ├── HandleIIDDataTFRecord.py
        ├── HandleImageDataNumpy.py
        ├── cifar10.py
        ├── dataset_utils.py
        ├── layers.py
        ├── losses.py
        └── svhn.py

/README.md:
--------------------------------------------------------------------------------
1 | # LVAT
2 | The code used in the evaluation of Latent Space Virtual Adversarial Training (ECCV 2020 Oral).
3 | 
4 | ## Reference
5 | This code is based on the original VAT implementation (see [here](https://github.com/takerum/vat_tf)),
6 | and the implementation for Glow is heavily based on [this code](https://github.com/kmkolasinski/deep-learning-notes/tree/3c7779ea0063896bb3a759efa3e52d173aaae94b/seminars/2018-10-Normalizing-Flows-NICE-RealNVP-GLOW).
7 | 
8 | ## Requirements
9 | tensorflow-gpu 1.14
10 | 
11 | ## Preparation 1. Create symbolic links
12 | 
13 | ```
14 | cd lvat/
15 | # for LVAT-VAE
16 | ln -s ../vae/VAE.py .
17 | ln -s ../vae/config.py .
18 | ln -s ../vae/util .
19 | ln -s ../vae/out_VAE_SVHN/ out_VAE_SVHN
20 | ln -s ../vae/out_VAE_SVHN_aug/ out_VAE_SVHN_aug
21 | ln -s ../vae/out_VAE_CIFAR10/ out_VAE_CIFAR10
22 | ln -s ../vae/out_VAE_CIFAR10_aug/ out_VAE_CIFAR10_aug
23 | 
24 | # for LVAT-Glow
25 | ln -s ../glow/out/SVHN/w_128__step_22__scale_3__b_128/ out_Glow_SVHN
26 | ln -s ../glow/out/SVHN_aug/w_128__step_22__scale_3__b_128/ out_Glow_SVHN_aug
27 | ln -s ../glow/out/CIFAR10/w_128__step_22__scale_3__b_128/ out_Glow_CIFAR10
28 | ln -s ../glow/out/CIFAR10_aug/w_128__step_22__scale_3__b_128/ out_Glow_CIFAR10_aug
29 | ```
30 | 
31 | ## Preparation 2. Create tfrecords
32 | 
33 | ```
34 | cd vae/util/
35 | ```
36 | and
37 | ```
38 | python svhn.py --data_dir=
39 | ```
40 | or
41 | ```
42 | python cifar10.py --data_dir=
43 | ```
44 | 
45 | ## Preparation 3. Building the transformer (VAE/Glow)
46 | 
47 | For VAE,
48 | ```
49 | cd vae
50 | python build_AE.py
51 | ```
52 | and for Glow,
53 | ```
54 | cd glow
55 | python main.py
56 | ```
57 | In both cases, the target dataset is specified in config.py (for VAE) or config_glow.py (for Glow).
58 | For VAE, you have to rename the directory in which the trained model is saved so that it matches the one referenced by the symbolic links created in Preparation 1.
59 | For Glow, the trained model can be referenced as is.
60 | 
61 | 
62 | 
63 | ## Training Classifier with LVAT
64 | ```
65 | cd lvat
66 | ```
67 | and for example
68 | ```
69 | python train_semisup.py --data_set=SVHN --num_epochs=200 \
70 |   --epoch_decay_start=80 --epsilon=1.5 --top_bn --method=lvat --log__dir=./out \
71 |   --data__dir= --num_iter_per_epoch=400 \
72 |   --batch_size=32 --ul_batch_size=128 --num_labeled_examples=1000 \
73 |   --is_aug=True --ae_type=Glow
74 | ```
75 | # Description
76 | - `--data_set` can be either `SVHN` or `CIFAR10`.
77 | - The directory given as `--data__dir` should be the same as the one you specified in Preparation 2 above.
78 | - `--log__dir` is the directory where the checkpoint files will be saved.
79 | - `--epsilon` is the magnitude of the perturbation; it is used by both `--method=lvat` and `--method=vat`.
80 | - If you set `--method=vat`, it works as the original VAT.
81 | - You can choose the transformer between VAE and Glow with `--ae_type=VAE` and `--ae_type=Glow`, respectively.
82 | 
83 | # Important
84 | For SVHN, the `--top_bn` option is necessary to achieve good results.
85 | 
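For orientation, the sketch below (not part of the original README) illustrates what `--method=lvat` computes: the usual VAT power-iteration step is carried out in the latent space of the trained transformer, and the resulting perturbation is injected back through the decoder. `encode`, `decode`, and `classifier` are placeholder callables, and for Glow the latent is really the tuple `(y, logdet, z)` returned by `Glow.encoder`; the actual implementation lives in `lvat/vat.py` / `lvat/train_semisup.py`.

```
import tensorflow as tf

def normalize_l2(d):
    # rescale each example's perturbation to unit L2 norm
    shape = tf.shape(d)
    d_flat = tf.reshape(d, [shape[0], -1])
    d_flat = d_flat / (tf.norm(d_flat, axis=1, keepdims=True) + 1e-12)
    return tf.reshape(d_flat, shape)

def lvat_perturbed_images(x, classifier, encode, decode, epsilon=1.5, xi=1e-6):
    """One VAT power-iteration step, carried out in the latent space z = encode(x)."""
    z = encode(x)
    p = tf.stop_gradient(tf.nn.softmax(classifier(decode(z))))  # clean prediction
    d = normalize_l2(tf.random_normal(tf.shape(z)))             # random direction in z-space
    # cross-entropy against p equals KL(p || q) up to a constant in d
    q_logits = classifier(decode(z + xi * d))
    dist = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=p, logits=q_logits))
    grad = tf.stop_gradient(tf.gradients(dist, [d])[0])
    r_adv = epsilon * normalize_l2(grad)          # adversarial direction in z-space
    return decode(z + r_adv)                      # decode back to image space
```

Compared with `--method=vat`, only the space in which `d` and `r_adv` live changes; `--epsilon` plays the same role in both.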
--------------------------------------------------------------------------------
/glow/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/glow/Glow.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | 
4 | 
5 | import os, sys
6 | 
7 | 
8 | import numpy as np
9 | import tensorflow as tf
10 | from scipy.stats import norm
11 | import matplotlib.pyplot as plt
12 | from tqdm import tqdm
13 | 
14 | 
15 | import nets
16 | import flow_layers as fl
17 | 
18 | import config_glow as c
19 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/util/')
20 | 
21 | import utils
22 | 
23 | tf.set_random_seed(0)
24 | 
25 | 
26 | class Glow():
27 |     def __init__(self):
28 | 
29 |         nn_template_fn = nets.OpenAITemplate(
30 |             width=c.WIDTH_RESNET
31 |         )
32 | 
33 |         layers, self.actnorm_layers = nets.create_simple_flow(
34 |             num_steps=c.N_FLOW_STEPS,
35 |             num_scales=c.N_FLOW_SCALES,
36 |             template_fn=nn_template_fn
37 |         )
38 |         self.model_flow = fl.ChainLayer(layers)
39 |         self.quantize_image_layer = layers[0]
40 | 
41 | 
42 |     def encoder(self, x):
43 | 
44 |         #if x is None:
45 |         #    x = self.x_image_ph
46 |         #flow = fl.InputLayer(self.x_image_ph)
47 | 
48 |         flow = fl.InputLayer(x)
49 |         output_flow = self.model_flow(flow, forward=True)
50 | 
51 | 
52 |         # ## Prepare output tensors
53 | 
54 |         y, logdet, z = output_flow
55 |         return y, logdet, z
56 |         #self.y, self.logdet, self.z = output_flow
57 |         #return self.y, self.logdet, self.z
58 | 
59 | 
60 |     def decoder(self, flow, is_to_uint8=False):
61 | 
62 |         """ flow: (y,logdet,z) """
63 | 
64 |         #if flow is None:
65 |         #    flow = (self.y_ph, self.logdet_pior_tmp, self.z_ph )
66 | 
67 |         x_reconst, _, _ = self.model_flow(flow, forward=False)
68 |         if is_to_uint8:
69 |             # return [0,255] non-differentiable
70 |             return self.quantize_image_layer.to_uint8(x_reconst)
71 |         else:
72 |             # return [0,1] differentiable
73 |             return self.quantize_image_layer.defferentiable_quantize(x_reconst)/255.0
74 | 
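    # Round-trip sketch (added comment, not in the original file): with m = Glow(),
    #   y, logdet, z = m.encoder(x)           # forward pass over a batch of images
    #   x_rec = m.decoder((y, logdet, z))     # inverse pass; pixel values in [0, 1]
    # `logdet` accumulates log|det J| of the flow; it feeds the MLE loss in
    # build_graph_train() below but is not needed for reconstruction itself.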
75 |     def build_graph_train(self, x=None):
76 | 
77 |         with tf.variable_scope('encoder') as scope:
78 |             y, logdet, z = self.encoder(x)
79 | 
80 |         self.y, self.logdet, self.z = y, logdet, z
81 |         logdet_pior_tmp = tf.zeros_like(logdet)
82 | 
83 | 
84 |         #################################################
85 |         """ Loss """
86 |         #################################################
87 |         #
88 |         # * Here simply the $-logp(x)$
89 | 
90 |         tfd = tf.contrib.distributions
91 | 
92 |         self.beta_ph = tf.placeholder(tf.float32, [])
93 | 
94 |         y_flatten = tf.reshape(y, [c.BATCH_SIZE, -1])
95 |         z_flatten = tf.reshape(z, [c.BATCH_SIZE, -1])
96 | 
97 |         prior_y = tfd.MultivariateNormalDiag(loc=tf.zeros_like(y_flatten), scale_diag=self.beta_ph * tf.ones_like(y_flatten))
98 |         prior_z = tfd.MultivariateNormalDiag(loc=tf.zeros_like(z_flatten), scale_diag=self.beta_ph * tf.ones_like(z_flatten))
99 |         log_prob_y = prior_y.log_prob(y_flatten)
100 |         log_prob_z = prior_z.log_prob(z_flatten)
101 | 
102 |         # ### The MLE loss
103 | 
104 |         loss = log_prob_y + log_prob_z + logdet
105 |         loss = - tf.reduce_mean(loss)
106 | 
107 | 
108 |         # ### The L2 regularization loss
109 | 
110 |         print('... setting up L2 regularization')
111 |         trainable_variables = tf.trainable_variables()
112 |         l2_reg = 0.00001
113 |         l2_loss = l2_reg * tf.add_n([ tf.nn.l2_loss(v) for v in tqdm(trainable_variables, total=len(trainable_variables), leave=False)])
114 | 
115 | 
116 |         # ### Total loss -logp(x) + l2_loss
117 | 
118 |         loss_per_pixel = loss / c.IMAGE_SIZE / c.IMAGE_SIZE
119 |         total_loss = l2_loss + loss_per_pixel
120 | 
121 |         # it should be moved to main()
122 |         #sess.run(tf.global_variables_initializer())
123 | 
124 |         #################################################
125 |         """ Trainer """
126 |         #################################################
127 | 
128 |         self.lr_ph = tf.placeholder(tf.float32)
129 |         print('... setting up optimizer')
130 |         optimizer = tf.train.AdamOptimizer(self.lr_ph)
131 |         self.train_op = optimizer.minimize(total_loss)
132 | 
133 | 
134 |         # it should be moved to main()
135 |         # ## Initialize Actnorms using DDI
136 |         """
137 |         sess.run(tf.global_variables_initializer())
138 |         nets.initialize_actnorms(
139 |             sess,
140 |             feed_dict_fn=lambda: {beta_ph: 1.0},
141 |             actnorm_layers=actnorm_layers,
142 |             num_steps=10,
143 |         )
144 |         """
145 | 
146 | 
147 |         # ## Train model, define metrics and trainer
148 | 
149 |         print('... setting up training metrics')
150 |         self.metrics = utils.Metrics(50, metrics_tensors={"total_loss": total_loss, "loss_per_pixel": loss_per_pixel, "l2_loss": l2_loss})
151 |         self.plot_metrics_hook = utils.PlotMetricsHook(self.metrics, step=1000)
152 | 
153 |         #################################################
154 |         """ Backward Flow """
155 |         #################################################
156 | 
157 |         with tf.variable_scope('decoder') as scope:
158 |             self.x_reconst_train = self.decoder((y, logdet, z))
159 | 
160 |         sample_y_flatten = prior_y.sample()
161 |         sample_y = tf.reshape(sample_y_flatten, y.shape.as_list())
162 |         sample_z = tf.reshape(prior_z.sample(), z.shape.as_list())
163 |         sampled_logdet = prior_y.log_prob(sample_y_flatten)
164 | 
165 |         with tf.variable_scope(scope, reuse=True):
166 |             self.x_sampled_train = self.decoder((sample_y, sampled_logdet, sample_z))
167 |         return
168 | 
169 |     def build_graph_test(self, x=None):
170 | 
171 |         with tf.variable_scope('encoder') as scope:
172 |             y, logdet, z = self.encoder(x)
173 | 
174 |         self.y, self.logdet, self.z = y, logdet, z
175 |         logdet_pior_tmp = tf.zeros_like(logdet)
176 | 
177 |         self.y_ph = tf.placeholder(tf.float32, y.shape.as_list())
178 |         self.z_ph = tf.placeholder(tf.float32, z.shape.as_list())
179 | 
180 |         with tf.variable_scope('decoder') as scope:
181 |             self.x_reconst = self.decoder((y, logdet, z))
182 | 
183 |         with tf.variable_scope(scope, reuse=True):
184 |             self.x_sampled = self.decoder((self.y_ph, logdet_pior_tmp, self.z_ph))
185 | 
186 |         """ test code
187 |         with tf.variable_scope('encoder') as scope:
188 |             with tf.variable_scope(scope, reuse=True):
189 |                 y_2, logdet_2, z_2 = self.encoder(x_reconst)
190 | 
191 |         with tf.variable_scope('decoder') as scope:
192 |             with tf.variable_scope(scope, reuse=True):
193 |                 x_reconst_2 = self.decoder((y_2, logdet_2, z_2))
194 |                 self.x_reconst_2 = self.quantize_image_layer.to_uint8(x_reconst_2)
195 |         """
196 | 
197 |         """
198 |         x = x_reconst_2
199 |         for i in range(1):
200 |             print('reconst:',i)
201 |             with tf.variable_scope('encoder') as scope:
202 |                 with tf.variable_scope(scope, reuse=True):
203 |                     y, logdet, z = self.encoder(x)
204 | 
205 |             with tf.variable_scope('decoder') as scope:
206 |                 with tf.variable_scope(scope, reuse=True):
207 |                     x = self.decoder((y, logdet, z))
208 | 
209 |         self.x_reconst_n = self.quantize_image_layer.to_uint8(x)
210 |         """
211 | 
212 |         return
213 | 
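# Usage sketch (added comment): after restoring a checkpoint as in main.py,
# build_graph_test() exposes placeholders for decoding arbitrary latent codes:
#
#   m = Glow(); m.build_graph_test(x)
#   y0, z0 = sess.run([m.y, m.z])                             # encode a batch
#   x_dec = sess.run(m.x_sampled, {m.y_ph: y0, m.z_ph: z0})   # decode latents
#
# Interpolations, or the latent perturbations used by LVAT, amount to modifying
# y0/z0 before decoding.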
--------------------------------------------------------------------------------
/glow/config_glow.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os,sys
4 | 
5 | # to import models
6 | #sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../CNN')
7 | 
8 | 
9 | """
10 | #tf.flags.DEFINE_string("data_set", "CIFAR10", "SVHN /CIFAR10 / CelebA")
11 | tf.flags.DEFINE_string("data_set", "CelebA", "SVHN /CIFAR10 / CelebA")
12 | tf.flags.DEFINE_boolean("restore", False, "restore from the last check point")
13 | """
14 | tf.flags.DEFINE_string("dir_root", "./out/", "")
15 | tf.flags.DEFINE_string("file_ckpt", "", "")
16 | 
17 | FLAGS = tf.flags.FLAGS  # data_set / restore / is_aug are defined by the importing script (main.py)
18 | 
19 | if FLAGS.data_set == "CelebA":
20 |     WIDTH_RESNET = 128
21 |     N_FLOW_STEPS = 22   # overridden just below
22 |     N_FLOW_STEPS = 32
23 |     N_FLOW_SCALES = 4
24 |     IMAGE_SIZE = 64
25 |     BATCH_SIZE = 4 # overridden below
26 |     BATCH_SIZE = 128 # just for interpolation (overridden below)
27 |     BATCH_SIZE = 16 # effective value
28 | 
29 | elif FLAGS.data_set == "SVHN" or FLAGS.data_set == 'CIFAR10':
30 |     WIDTH_RESNET = 128
31 |     N_FLOW_STEPS = 22
32 |     N_FLOW_SCALES = 3
33 |     IMAGE_SIZE = 32
34 |     BATCH_SIZE = 128
35 | else:
36 |     raise ValueError
37 | 
38 | if FLAGS.is_aug:
39 |     dir_logs = os.path.join(FLAGS.dir_root, FLAGS.data_set + '_aug')
40 | else:
41 |     dir_logs = os.path.join(FLAGS.dir_root, FLAGS.data_set)
42 | dir_logs = os.path.join(dir_logs, "w_%d__step_%d__scale_%d__b_%d"%(WIDTH_RESNET, N_FLOW_STEPS, N_FLOW_SCALES, BATCH_SIZE))
43 | 
44 | FLAGS.file_ckpt = os.path.join(dir_logs,"model.ckpt")
45 | 
46 | print('checkpoint:', FLAGS.file_ckpt)
47 | os.makedirs(dir_logs, exist_ok=True)
48 | 
49 | IS_DRYRUN = False
50 | if IS_DRYRUN:
51 |     sess = tf.InteractiveSession()
52 |     a = tf.Variable(0)
53 |     sess.run(tf.global_variables_initializer())
54 |     saver = tf.train.Saver()
55 |     save_path = saver.save(sess, FLAGS.file_ckpt)
56 |     print("Dryrun ... Model will be saved in path: %s" % save_path)
57 |     sys.exit('exit dry run')
58 | 
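# Note (added comment): with the SVHN/CIFAR10 settings above, dir_logs resolves
# to e.g. ./out/SVHN/w_128__step_22__scale_3__b_128, which is exactly the
# directory targeted by the `ln -s` commands in "Preparation 1" of the README.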
saving model: %s" % FLAGS.file_ckpt) 47 | save_path = saver.save(sess, FLAGS.file_ckpt) 48 | 49 | if do_save_image: 50 | fig = c.dir_logs + '/' + 'fig_glow__%d'%(epoch) 51 | print("... saving figure: %s" % fig) 52 | 53 | plt.subplot(121) 54 | plt.imshow(utils.plot_grid(m.x_sampled_train).eval({m.lr_ph: 0.0, m.beta_ph: 0.9})) 55 | plt.subplot(122) 56 | plt.imshow(utils.plot_grid(m.x_sampled_train).eval({m.lr_ph: 0.0, m.beta_ph: 1.0})) 57 | #plt.show() 58 | plt.savefig(fig) 59 | 60 | 61 | def main(): 62 | 63 | sess = tf.InteractiveSession() 64 | 65 | if FLAGS.data_set == "SVHN": 66 | from HandleSVHN import HandleSVHN 67 | d = HandleSVHN() 68 | (x,_), (_,_) = d.get_data(batch_size=c.BATCH_SIZE,image_size=c.IMAGE_SIZE) 69 | 70 | elif FLAGS.data_set == "CIFAR10": 71 | from HandleCIFAR10 import HandleCIFAR10 72 | d = HandleCIFAR10() 73 | (x,_), (_,_) = d.get_data(batch_size=c.BATCH_SIZE,image_size=c.IMAGE_SIZE) 74 | 75 | else: 76 | raise ValueError 77 | 78 | scope_name = 'scope_glow' 79 | 80 | with tf.variable_scope(scope_name ) as scope: 81 | m = Glow() 82 | m.build_graph_train(x) 83 | 84 | saver = tf.train.Saver() 85 | 86 | if FLAGS.restore: 87 | print("... restore with:", c.FLAGS.file_ckpt) 88 | saver.restore(sess, c.FLAGS.file_ckpt) 89 | else: 90 | sess.run(tf.global_variables_initializer()) 91 | nets.initialize_actnorms( 92 | sess, 93 | feed_dict_fn=lambda: {m.beta_ph: 1.0}, 94 | actnorm_layers=m.actnorm_layers, 95 | num_steps=10, 96 | ) 97 | 98 | 99 | sess.run(m.train_op, feed_dict={m.lr_ph: 0.0, m.beta_ph: 1.0}) 100 | 101 | # ### Train model 102 | # 103 | # * We start from small learning rate (warm-up) 104 | 105 | trainer( 1, 0.0001, m, sess, saver, n_steps=100 ) 106 | 107 | trainer( 5, 0.0005, m, sess, saver, n_steps=100 ) 108 | 109 | trainer( 5, 0.0001, m, sess, saver) 110 | 111 | if FLAGS.is_aug: trainer( 5, 0.0001, m, sess, saver) 112 | 113 | trainer( 5, 0.00005, m, sess, saver) 114 | 115 | if FLAGS.is_aug: trainer( 5, 0.00005, m, sess, saver) 116 | 117 | trainer( 5, 0.0001, m, sess, saver) 118 | 119 | trainer( 1, 0.0001, m, sess, saver, n_steps=0, do_save_model=True) 120 | 121 | plot_metrics_hook.run() 122 | 123 | if __name__ == "__main__": 124 | 125 | main() 126 | -------------------------------------------------------------------------------- /glow/nets.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Dict, Any, Tuple, NamedTuple 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow.contrib.layers as tf_layers 6 | from tensorflow.python.ops import template as template_ops 7 | from tqdm import tqdm 8 | 9 | import flow_layers as fl 10 | import tf_ops 11 | import tf_ops as ops 12 | 13 | 14 | K = tf.keras.backend 15 | keras = tf.keras 16 | 17 | 18 | def simple_resnet_template_fn( 19 | name: str, 20 | activation_fn=tf.nn.relu, 21 | units_factor: int = 2, 22 | num_blocks: int = 1, 23 | units_width: int = 0, 24 | selu_reg_scale: float = 0.0, 25 | skip_connection: bool = True 26 | ): 27 | """ 28 | Creates simple Resnet shallow network. Note that this function will return a 29 | tensorflow template. 30 | Args: 31 | name: a scope name of the network 32 | activation_fn: activation function used after each conv layer 33 | units_factor: a base scale of the numbers of units in the resnet block. 34 | The number of units is computed as units_factor * num_channels. 35 | num_blocks: num resnet blocks 36 | units_width: number of units in the resnet. 
if 0 then units_factor 37 | is used to estimate num_units in the conv2d 38 | selu_reg_scale: conv weights selu like regularization 39 | skip_connection: whether to use skip connections or not 40 | 41 | Returns: 42 | a template function 43 | """ 44 | 45 | if selu_reg_scale == 0: 46 | reg_fn = lambda: None 47 | else: 48 | reg_fn = lambda: tf_ops.conv2d_selu_regularizer(selu_reg_scale) 49 | 50 | def _shift_and_log_scale_fn(x: tf.Tensor): 51 | shape = K.int_shape(x) 52 | num_channels = shape[3] 53 | num_units = num_channels * units_factor 54 | if units_width != 0: 55 | num_units = units_width 56 | 57 | h = x 58 | for u in range(num_blocks): 59 | with tf.variable_scope(f"ResnetBlock{u}"): 60 | h_input = h 61 | # nn definition 62 | h = tf_layers.conv2d( 63 | inputs=h_input, 64 | num_outputs=num_units, 65 | kernel_size=3, 66 | activation_fn=activation_fn, 67 | weights_regularizer=reg_fn() 68 | ) 69 | h = tf_layers.conv2d( 70 | inputs=h, 71 | num_outputs=num_units, 72 | kernel_size=1, 73 | activation_fn=None, 74 | ) 75 | if skip_connection: 76 | 77 | if num_units != K.int_shape(h_input)[3]: 78 | h_input = tf_layers.conv2d( 79 | inputs=h_input, 80 | num_outputs=num_units, 81 | kernel_size=1, 82 | activation_fn=activation_fn, 83 | weights_regularizer=reg_fn() 84 | ) 85 | 86 | h = h + h_input 87 | 88 | h = activation_fn(h) 89 | 90 | # create shift and log_scale with (almost) zero initialization 91 | shift_log_scale = tf_layers.conv2d( 92 | inputs=h, 93 | num_outputs=2 * num_channels, 94 | weights_initializer=tf.variance_scaling_initializer(scale=0.001), 95 | kernel_size=3, 96 | activation_fn=None, 97 | normalizer_fn=None, 98 | ) 99 | shift = shift_log_scale[:, :, :, :num_channels] 100 | log_scale = shift_log_scale[:, :, :, num_channels:] 101 | log_scale = tf.clip_by_value(log_scale, -15.0, 15.0) 102 | return shift, log_scale 103 | 104 | return template_ops.make_template(name, _shift_and_log_scale_fn) 105 | 106 | 107 | class TemplateFn: 108 | def __init__( 109 | self, 110 | params: Dict[str, Any], 111 | template_fn: Callable[[str], Any] 112 | ): 113 | self._params = params 114 | self._template_fn = template_fn 115 | 116 | def create_template_fn(self, name: str): 117 | return self._template_fn(name=name, **self._params) 118 | 119 | 120 | class ResentTemplate(TemplateFn): 121 | def __init__( 122 | self, 123 | activation_fn=tf.nn.relu, 124 | units_factor: int = 2, 125 | num_blocks: int = 1, 126 | units_width: int = 0, 127 | skip_connection: bool = True, 128 | selu_reg_scale: float = 0.001, 129 | ) -> None: 130 | params = { 131 | "activation_fn": activation_fn, 132 | "units_factor": units_factor, 133 | "num_blocks": num_blocks, 134 | "units_width": units_width, 135 | "skip_connection": skip_connection, 136 | "selu_reg_scale": selu_reg_scale, 137 | } 138 | super().__init__( 139 | params=params, 140 | template_fn=simple_resnet_template_fn 141 | ) 142 | 143 | 144 | class OpenAITemplate(NamedTuple): 145 | """ 146 | A shallow neural network used by GLOW paper: 147 | * https://github.com/openai/glow 148 | 149 | activation_fn: activation function used after each conv layer 150 | width: number of filters in the shallow network 151 | """ 152 | activation_fn: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu 153 | width: int = 32 154 | 155 | def create_template_fn( 156 | self, 157 | name: str, 158 | ) -> Callable[[tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]: 159 | """ 160 | Creates simple shallow network. Note that this function will return a 161 | tensorflow template. 
162 | Args: 163 | name: a scope name of the network 164 | Returns: 165 | a template function 166 | """ 167 | 168 | def _shift_and_log_scale_fn(x: tf.Tensor): 169 | shape = K.int_shape(x) 170 | num_channels = shape[3] 171 | 172 | with tf.variable_scope("BlockNN"): 173 | h = x 174 | h = self.activation_fn(ops.conv2d("l_1", h, self.width)) 175 | h = self.activation_fn( 176 | ops.conv2d("l_2", h, self.width, filter_size=[1, 1])) 177 | # create shift and log_scale with zero initialization 178 | shift_log_scale = ops.conv2d_zeros( 179 | "l_last", h, 2 * num_channels 180 | ) 181 | shift = shift_log_scale[:, :, :, 0::2] 182 | log_scale = shift_log_scale[:, :, :, 1::2] 183 | log_scale = tf.clip_by_value(log_scale, -15.0, 15.0) 184 | return shift, log_scale 185 | 186 | return template_ops.make_template(name, _shift_and_log_scale_fn) 187 | 188 | 189 | def step_flow( 190 | name: str, 191 | shift_and_log_scale_fn: Callable[[tf.Tensor], tf.Tensor] 192 | ) -> Tuple[fl.ChainLayer, fl.ActnormLayer]: 193 | """Create single step of the Glow model: 194 | 195 | 1. actnorm 196 | 2. invertible conv 197 | 3. affine coupling layer 198 | 199 | Returns: 200 | step_layer: a flow layer which perform 1-3 operations 201 | actnorm: a reference of actnorm layer from step 1. This reference can be 202 | used to initialize this layer using data dependent initialization 203 | """ 204 | actnorm = fl.ActnormLayer() 205 | layers = [ 206 | actnorm, 207 | fl.InvertibleConv1x1Layer(), 208 | fl.AffineCouplingLayer(shift_and_log_scale_fn=shift_and_log_scale_fn), 209 | ] 210 | return fl.ChainLayer(layers, name=name), actnorm 211 | 212 | 213 | def initialize_actnorms( 214 | sess: tf.Session(), 215 | feed_dict_fn: Callable[[], Dict[tf.Tensor, np.ndarray]], 216 | actnorm_layers: List[fl.ActnormLayer], 217 | num_steps: int = 100, 218 | num_init_iterations: int = 10, 219 | ) -> None: 220 | """Initialize actnorm layers using data dependent initialization 221 | 222 | Args: 223 | sess: an instance of tf.Session 224 | feed_dict_fn: a feed dict function which return feed_dict to the tensorflow 225 | sess.run function. 226 | actnorm_layers: a list of actnorms to initialize 227 | num_steps: number of batches to used for iterative initialization. 228 | num_init_iterations: a get_ddi_init_ops parameter. For more details 229 | see the implementation. 230 | """ 231 | for actnorm_layer in tqdm(actnorm_layers): 232 | init_op = actnorm_layer.get_ddi_init_ops(num_init_iterations) 233 | for i in range(num_steps): 234 | sess.run(init_op, feed_dict=feed_dict_fn()) 235 | 236 | 237 | def create_simple_flow( 238 | num_steps: int = 1, 239 | num_scales: int = 3, 240 | num_bits: int = 5, 241 | template_fn: Any = ResentTemplate() 242 | ) -> Tuple[List[fl.FlowLayer], List[fl.ActnormLayer]]: 243 | """Create Glow model. This implementation may slightly differ from the 244 | official one. For example the last layer here is the fl.FactorOutLayer 245 | 246 | Args: 247 | num_steps: number of steps per single scale, a K parameter from the paper 248 | num_scales: number of scales, a L parameter from the paper. Each scale 249 | reduces the tensor spatial dimension by 2. 250 | num_bits: input image quantization 251 | template_fn: a template function used in AffineCoupling layer 252 | 253 | Returns: 254 | layers: a list of layers which define normalizing flow 255 | actnorms: a list of actnorm layers which can be initialized using data 256 | dependent initialization. See: initialize_actnorms() function. 
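    Example (illustrative; mirrors how Glow.py builds the model):

        template = OpenAITemplate(width=128)
        layers, actnorms = create_simple_flow(
            num_steps=22, num_scales=3, template_fn=template)
        model_flow = fl.ChainLayer(layers)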
257 | """ 258 | layers = [fl.QuantizeImage(num_bits=num_bits)] 259 | actnorm_layers = [] 260 | for scale in range(num_scales): 261 | scale_name = f"Scale{scale+1}" 262 | scale_steps = [] 263 | for s in range(num_steps): 264 | name = f"Step{s+1}" 265 | step_layer, actnorm_layer = step_flow( 266 | name=name, 267 | shift_and_log_scale_fn=template_fn.create_template_fn(name) 268 | ) 269 | scale_steps.append(step_layer) 270 | actnorm_layers.append(actnorm_layer) 271 | 272 | layers += [ 273 | fl.SqueezingLayer(name=scale_name), 274 | fl.ChainLayer(scale_steps, name=scale_name), 275 | fl.FactorOutLayer(name=scale_name), 276 | ] 277 | 278 | return layers, actnorm_layers 279 | -------------------------------------------------------------------------------- /glow/tf_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is copied from GLOW github: 3 | https://github.com/openai/glow/blob/master/tfops.py 4 | 5 | And is used only by one class in nets.py: OpenAITemplate 6 | """ 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope 11 | 12 | SELU_CONV2D_REG_LOSS = "selu_conv2d_reg_loss" 13 | 14 | 15 | def default_initial_value(shape, std=0.05): 16 | return tf.random_normal(shape, 0., std) 17 | 18 | 19 | def default_initializer(std=0.05): 20 | return tf.random_normal_initializer(0., std) 21 | 22 | 23 | def int_shape(x): 24 | if str(x.get_shape()[0]) != '?': 25 | return list(map(int, x.get_shape())) 26 | return [-1] + list(map(int, x.get_shape()[1:])) 27 | 28 | 29 | # wrapper tf.get_variable, augmented with 'init' functionality 30 | # Get variable with data dependent init 31 | 32 | 33 | @add_arg_scope 34 | def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False, 35 | trainable=True): 36 | w = tf.get_variable(name, shape, dtype, None, trainable=trainable) 37 | if init: 38 | w = w.assign(initial_value) 39 | with tf.control_dependencies([w]): 40 | return w 41 | return w 42 | 43 | 44 | # Activation normalization 45 | # Convenience function that does centering+scaling 46 | 47 | @add_arg_scope 48 | def actnorm(name, x, scale=1., logdet=None, logscale_factor=3., 49 | batch_variance=False, reverse=False, init=False, trainable=True): 50 | if arg_scope([get_variable_ddi], trainable=trainable): 51 | if not reverse: 52 | x = actnorm_center(name + "_center", x, reverse) 53 | x = actnorm_scale(name + "_scale", x, scale, logdet, 54 | logscale_factor, batch_variance, reverse, init) 55 | if logdet != None: 56 | x, logdet = x 57 | else: 58 | x = actnorm_scale(name + "_scale", x, scale, logdet, 59 | logscale_factor, batch_variance, reverse, init) 60 | if logdet != None: 61 | x, logdet = x 62 | x = actnorm_center(name + "_center", x, reverse) 63 | if logdet != None: 64 | return x, logdet 65 | return x 66 | 67 | 68 | # Activation normalization 69 | 70 | 71 | @add_arg_scope 72 | def actnorm_center(name, x, reverse=False): 73 | shape = x.get_shape() 74 | with tf.variable_scope(name): 75 | assert len(shape) == 2 or len(shape) == 4 76 | if len(shape) == 2: 77 | x_mean = tf.reduce_mean(x, [0], keepdims=True) 78 | b = get_variable_ddi( 79 | "b", (1, int_shape(x)[1]), initial_value=-x_mean) 80 | elif len(shape) == 4: 81 | x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True) 82 | b = get_variable_ddi( 83 | "b", (1, 1, 1, int_shape(x)[3]), initial_value=-x_mean) 84 | 85 | if not reverse: 86 | x += b 87 | else: 88 | x -= b 89 | 90 | return x 91 | 92 | 93 | # Activation 
normalization 94 | @add_arg_scope 95 | def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3., 96 | batch_variance=False, reverse=False, init=False, 97 | trainable=True): 98 | shape = x.get_shape() 99 | with tf.variable_scope(name), arg_scope([get_variable_ddi], 100 | trainable=trainable): 101 | assert len(shape) == 2 or len(shape) == 4 102 | if len(shape) == 2: 103 | x_var = tf.reduce_mean(x ** 2, [0], keepdims=True) 104 | logdet_factor = 1 105 | _shape = (1, int_shape(x)[1]) 106 | 107 | elif len(shape) == 4: 108 | x_var = tf.reduce_mean(x ** 2, [0, 1, 2], keepdims=True) 109 | logdet_factor = int(shape[1]) * int(shape[2]) 110 | _shape = (1, 1, 1, int_shape(x)[3]) 111 | 112 | if batch_variance: 113 | x_var = tf.reduce_mean(x ** 2, keepdims=True) 114 | 115 | 116 | logs = get_variable_ddi("logs", _shape, initial_value=tf.log( 117 | scale / (tf.sqrt( 118 | x_var) + 1e-6)) / logscale_factor) * logscale_factor 119 | if not reverse: 120 | x = x * tf.exp(logs) 121 | else: 122 | x = x * tf.exp(-logs) 123 | 124 | if logdet != None: 125 | dlogdet = tf.reduce_sum(logs) * logdet_factor 126 | if reverse: 127 | dlogdet *= -1 128 | return x, logdet + dlogdet 129 | 130 | return x 131 | 132 | 133 | # Linear layer with layer norm 134 | @add_arg_scope 135 | def linear(name, x, width, do_weightnorm=True, do_actnorm=True, 136 | initializer=None, scale=1.): 137 | initializer = initializer or default_initializer() 138 | with tf.variable_scope(name): 139 | n_in = int(x.get_shape()[1]) 140 | w = tf.get_variable("W", [n_in, width], 141 | tf.float32, initializer=initializer) 142 | if do_weightnorm: 143 | w = tf.nn.l2_normalize(w, [0]) 144 | x = tf.matmul(x, w) 145 | x += tf.get_variable("b", [1, width], 146 | initializer=tf.zeros_initializer()) 147 | if do_actnorm: 148 | x = actnorm("actnorm", x, scale) 149 | return x 150 | 151 | 152 | # Linear layer with zero init 153 | @add_arg_scope 154 | def linear_zeros(name, x, width, logscale_factor=3): 155 | with tf.variable_scope(name): 156 | n_in = int(x.get_shape()[1]) 157 | w = tf.get_variable("W", [n_in, width], tf.float32, 158 | initializer=tf.zeros_initializer()) 159 | x = tf.matmul(x, w) 160 | x += tf.get_variable("b", [1, width], 161 | initializer=tf.zeros_initializer()) 162 | x *= tf.exp(tf.get_variable("logs", 163 | [1, width], 164 | initializer=tf.zeros_initializer()) * logscale_factor) 165 | return x 166 | 167 | 168 | # Slow way to add edge padding 169 | def add_edge_padding(x, filter_size): 170 | assert filter_size[0] % 2 == 1 171 | if filter_size[0] == 1 and filter_size[1] == 1: 172 | return x 173 | a = (filter_size[0] - 1) // 2 # vertical padding size 174 | b = (filter_size[1] - 1) // 2 # horizontal padding size 175 | if True: 176 | x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]]) 177 | name = "_".join([str(dim) for dim in [a, b, *int_shape(x)[1:3]]]) 178 | pads = tf.get_collection(name) 179 | if not pads: 180 | pad = np.zeros([1] + int_shape(x)[1:3] + [1], dtype='float32') 181 | pad[:, :a, :, 0] = 1. 182 | pad[:, -a:, :, 0] = 1. 183 | pad[:, :, :b, 0] = 1. 184 | pad[:, :, -b:, 0] = 1. 
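            # (added comment) the four assignments above build a 0/1 mask marking
            # the padded border; it is cached in a graph collection, tiled to the
            # batch size below, and concatenated as an extra input channel so the
            # convolution can tell real content from padding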
185 | pad = tf.convert_to_tensor(pad) 186 | tf.add_to_collection(name, pad) 187 | else: 188 | pad = pads[0] 189 | pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1]) 190 | x = tf.concat([x, pad], axis=3) 191 | else: 192 | pad = tf.pad(tf.zeros_like(x[:, :, :, :1]) - 1, 193 | [[0, 0], [a, a], [b, b], [0, 0]]) + 1 194 | x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]]) 195 | x = tf.concat([x, pad], axis=3) 196 | return x 197 | 198 | 199 | @add_arg_scope 200 | def conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", 201 | do_weightnorm=False, do_actnorm=True, context1d=None, skip=1, 202 | edge_bias=True): 203 | with tf.variable_scope(name): 204 | if edge_bias and pad == "SAME": 205 | x = add_edge_padding(x, filter_size) 206 | pad = 'VALID' 207 | 208 | n_in = int(x.get_shape()[3]) 209 | 210 | stride_shape = [1] + stride + [1] 211 | filter_shape = filter_size + [n_in, width] 212 | w = tf.get_variable("W", filter_shape, tf.float32, 213 | initializer=default_initializer()) 214 | if do_weightnorm: 215 | w = tf.nn.l2_normalize(w, [0, 1, 2]) 216 | if skip == 1: 217 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC') 218 | else: 219 | assert stride[0] == 1 and stride[1] == 1 220 | x = tf.nn.atrous_conv2d(x, w, skip, pad) 221 | if do_actnorm: 222 | x = actnorm("actnorm", x) 223 | else: 224 | x += tf.get_variable("b", [1, 1, 1, width], 225 | initializer=tf.zeros_initializer()) 226 | 227 | if context1d != None: 228 | x += tf.reshape(linear("context", context1d, 229 | width), [-1, 1, 1, width]) 230 | return x 231 | 232 | 233 | @add_arg_scope 234 | def conv2d_zeros(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", 235 | logscale_factor=3, skip=1, edge_bias=True): 236 | with tf.variable_scope(name): 237 | if edge_bias and pad == "SAME": 238 | x = add_edge_padding(x, filter_size) 239 | pad = 'VALID' 240 | 241 | n_in = int(x.get_shape()[3]) 242 | stride_shape = [1] + stride + [1] 243 | filter_shape = filter_size + [n_in, width] 244 | w = tf.get_variable("W", filter_shape, tf.float32, 245 | initializer=tf.zeros_initializer()) 246 | if skip == 1: 247 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC') 248 | else: 249 | assert stride[0] == 1 and stride[1] == 1 250 | x = tf.nn.atrous_conv2d(x, w, skip, pad) 251 | x += tf.get_variable("b", [1, 1, 1, width], 252 | initializer=tf.zeros_initializer()) 253 | x *= tf.exp(tf.get_variable("logs", 254 | [1, width], 255 | initializer=tf.zeros_initializer()) * logscale_factor) 256 | return x 257 | 258 | 259 | K = tf.keras.backend 260 | keras = tf.keras 261 | 262 | 263 | # inspired by loss of VAEs 264 | def fc_selu_reg(x: tf.Tensor, mu: float) -> tf.Tensor: 265 | # average over filter size 266 | mean = K.mean(x, axis=0) 267 | tau_sqr = K.mean(K.square(x), axis=0) 268 | # average over batch size 269 | mean_loss = K.mean(K.square(mean)) 270 | tau_loss = K.mean(tau_sqr - K.log(tau_sqr + K.epsilon())) 271 | return mu * (mean_loss + tau_loss) 272 | 273 | 274 | def conv2d_selu_regularizer(scale: float): 275 | def _regularizer_fn(weights: tf.Tensor) -> tf.Tensor: 276 | shape = K.int_shape(weights) 277 | num_filters = shape[-1] 278 | weights = K.reshape(weights, shape=[-1, num_filters]) 279 | with tf.name_scope("SELUConv2DRegLoss"): 280 | loss = fc_selu_reg(weights, scale) 281 | tf.add_to_collection(SELU_CONV2D_REG_LOSS, loss) 282 | return loss 283 | 284 | return _regularizer_fn 285 | -------------------------------------------------------------------------------- /glow/util/HandleCIFAR10.py: 
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import os, sys
4 | import os.path
5 | 
6 | import scipy.io
7 | import scipy.io.wavfile
8 | from imageio import imread # for scipy > 1.3.0
9 | import tensorflow as tf
10 | 
11 | from tqdm import tqdm
12 | import numpy as np
13 | 
14 | from utils import numpy_array_to_dataset
15 | from cifar10 import load_cifar10, DATA_DIR, NUM_EXAMPLES_TRAIN, NUM_EXAMPLES_TEST
16 | 
17 | tf.flags.DEFINE_bool("do_write_tfrecord", False, "if True, write tfrecords.")
18 | tf.flags.DEFINE_bool("do_write_npy", True, "if True, write npy.")
19 | 
20 | FLAGS = tf.flags.FLAGS
21 | 
22 | class HandleCIFAR10(object):
23 | 
24 |     def __init__(self):
25 |         self.n_train = NUM_EXAMPLES_TRAIN
26 |         self.n_valid = 0
27 |         self.n_test = NUM_EXAMPLES_TEST
28 | 
29 |     def get_ndarray(self):
30 | 
31 |         (x_train, y_train), (x_test, y_test) = load_cifar10()  # load_cifar10() takes no arguments
32 |         # y is integer at this moment
33 | 
34 |         #y_train = np.identity(10)[y_train]
35 |         #y_test = np.identity(10)[y_test ]
36 | 
37 |         print(x_train[:2])
38 |         print(x_train.shape)
39 |         print(y_train.shape)
40 |         print(x_test.shape)
41 |         print(y_test.shape)
42 |         return (x_train, y_train), (x_test, y_test)
43 |         #sys.exit
44 | 
45 |     def get_data(self,
46 |             image_size: int = 32,
47 |             batch_size: int = 16,
48 |             #batch_size: int = 4,
49 |             buffer_size: int = 512,
50 |             num_parallel_batches: int = 16, # org
51 |             #num_parallel_batches: int = 1,
52 |             ):
53 | 
54 | 
55 |         def get_dataset_from_ndarray():
56 |             (x_train, y_train), (x_test, y_test) = self.get_ndarray()
57 | 
58 |             def _zip(x,y):
59 |                 """ x,y are np.array """
60 |                 x_dataset = numpy_array_to_dataset(x,
61 |                     buffer_size=buffer_size, batch_size=batch_size, num_parallel_batches=num_parallel_batches)
62 | 
63 |                 y_dataset = numpy_array_to_dataset(y,
64 |                     buffer_size=buffer_size, batch_size=batch_size, num_parallel_batches=num_parallel_batches)
65 | 
66 |                 return tf.data.Dataset.zip((x_dataset, y_dataset))
67 | 
68 |             return _zip(x_train, y_train), _zip(x_test, y_test)
69 | 
70 |         def get_dataset_from_tfrecord():
71 |             DATASET_SEED = 1
72 |             dir_root = DATA_DIR + '/seed' + str(DATASET_SEED)
73 |             train_tfrecord = dir_root + '/' + 'labeled_train.tfrecords'
74 |             test_tfrecord = dir_root + '/' + 'test.tfrecords'
75 |             print('... set up to read from', dir_root)
76 |             return tf.data.TFRecordDataset(train_tfrecord), tf.data.TFRecordDataset(test_tfrecord)
77 | 
78 |         def apply_parser(dataset):
79 | 
80 |             def preprocess(example):
81 |                 features = tf.parse_single_example(
82 |                     example,
83 |                     features={
84 |                         "image" : tf.FixedLenFeature([32 * 32 * 3], tf.float32),
85 |                         "label": tf.FixedLenFeature([], tf.int64)
86 |                     }
87 |                 )
88 | 
89 |                 image = tf.reshape(features["image"], [32, 32, 3])
90 |                 label = features["label"]
91 | 
92 |                 return image, label
93 | 
94 |             dataset = dataset.apply(
95 |                 tf.contrib.data.map_and_batch(
96 |                     map_func=preprocess,
97 |                     batch_size=batch_size,
98 |                     num_parallel_batches=num_parallel_batches,
99 |                     drop_remainder=True,
100 |                 )
101 |             )
102 |             return dataset
103 | 
104 |         def read(dataset):
105 | 
106 | 
107 |             if buffer_size > 0:
108 |                 dataset = dataset.apply(
109 |                     tf.contrib.data.shuffle_and_repeat(buffer_size=buffer_size, count=-1)
110 |                 )
111 | 
112 |             if IS_FROM_TFRECORDS:
113 |                 dataset = apply_parser(dataset)
114 | 
115 | 
116 |             dataset = dataset.prefetch(4)
117 |             images, label = dataset.make_one_shot_iterator().get_next()
118 | 
119 |             x = tf.reshape(images, [batch_size, 32, 32, 3])
120 |             #x = tf.image.resize_images(
121 |             #    x, [image_size, image_size], method=0, align_corners=False
122 |             #)
123 |             y = tf.one_hot(tf.cast( label, tf.int32), 10)
124 | 
125 |             return x, y
126 | 
127 |         IS_FROM_TFRECORDS = True
128 |         if IS_FROM_TFRECORDS:
129 |             ds_train, ds_test = get_dataset_from_tfrecord()
130 |         else:
131 |             ds_train, ds_test = get_dataset_from_ndarray()
132 | 
133 |         return read(ds_train), read(ds_test)
134 | 
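    # Usage sketch (added comment; mirrors the __main__ block at the bottom of
    # this file):
    #   d = HandleCIFAR10()
    #   (x_train, y_train), (x_test, y_test) = d.get_data(batch_size=128)
    #   with tf.Session() as sess:
    #       images, labels = sess.run([x_train, y_train])  # (128, 32, 32, 3), (128, 10)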
135 |     def prepare(self):
136 | 
137 |         sys.exit('use cifar10.py instead')
138 | 
139 |         df_train, df_valid, df_test = get_train_val_test()
140 | 
141 |         if FLAGS.do_write_tfrecord:
142 |             self.write_tfrecord(df_train, TFRECORD_TRAIN)
143 |             self.write_tfrecord(df_valid, TFRECORD_VALID)
144 |             self.write_tfrecord(df_test, TFRECORD_TEST)
145 | 
146 |         if FLAGS.do_write_npy:
147 |             self.write_npy(df_train, NPY_TRAIN)
148 |             self.write_npy(df_valid, NPY_VALID)
149 |             self.write_npy(df_test, NPY_TEST)
150 |         print('... exit prepare()')
151 |         return
152 | 
153 |     def read_image( self, file_name):
154 | 
155 |         """ file_name: str. jpg image file name"""
156 | 
157 |         #image = imread(os.path.join(FLAGS.fn_root, file_name))
158 |         image = imread(os.path.join(PATH_TO_RAWIMG, file_name))
159 |         #
160 |         # cropping is performed here.
161 |         # this is the same way as RealNVP and kmkolasinski's Glow implementation
162 |         #
163 | 
164 |         # original shape [218, 178, 3] is going to be converted into [148, 148, 3] once,
165 |         image = image[40:188, 15:163, :]
166 |         image = image.reshape([148, 148, 3])
167 | 
168 |         # then convert it into (IMG_SIZE, IMG_SIZE)
169 |         import cv2
170 |         image = cv2.resize(image, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
171 | 
172 |         return image
173 | 
174 |     def write_npy(self, df, file_name):
175 | 
176 |         print('... writing npy file.')
177 |         x, y = [],[]
178 |         for i, row in enumerate(tqdm(df.itertuples(name=None), leave=False, total=len(df))):
179 | 
180 |             row = list(row) # tuple to list
181 |             img_file_name = row.pop(0)
182 | 
183 |             image = self.read_image(img_file_name)
184 |             x.append(image)
185 |             y.append(row)
186 | 
187 |         x = np.array(x)
188 |         y = np.array(y)
189 |         print(os.path.join(FLAGS.fn_root, 'x_' + file_name))
190 |         print(os.path.join(FLAGS.fn_root, 'y_' + file_name))
191 |         np.save(os.path.join(FLAGS.fn_root, 'x_' + file_name) , x)
192 |         np.save(os.path.join(FLAGS.fn_root, 'y_' + file_name) , y)
193 |         return
194 | 
195 |     def write_tfrecord(self, df, file_name):
196 | 
197 |         def make_example( image, attributes):
198 |             """ image: ndarray whose shape is (144, 144, 3)
199 |                 attributes: list of 40 features
200 |             """
201 | 
202 |             image = image.tostring()
203 |             example = tf.train.Example(
204 |                 features=tf.train.Features(
205 |                     feature={
206 |                         "image" : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
207 |                         "attributes": tf.train.Feature(int64_list=tf.train.Int64List(value=attributes))
208 |                     }
209 |                 )
210 |             )
211 |             return example
212 | 
213 |         print('... writing', file_name )
214 |         writer = tf.io.TFRecordWriter(file_name)
215 | 
216 |         for i, row in enumerate(tqdm(df.itertuples(name=None), leave=False, total=len(df))):
217 | 
218 |             row = list(row) # tuple to list
219 |             img_file_name = row.pop(0)
220 | 
221 |             image = self.read_image(img_file_name)
222 |             ex = make_example(image, row)
223 |             writer.write(ex.SerializeToString())
224 | 
225 |             #if i > 25: break
226 |         writer.close()
227 |         #sys.exit()
228 | 
229 | 
230 | if __name__ == "__main__":
231 | 
232 | 
233 |     d = HandleCIFAR10()
234 | 
235 |     #if FLAGS.is_write_mode:
236 |     #d.prepare()
237 | 
238 |     #print(d.get_ndarray())
239 |     print(d.get_data())
240 | 
241 |     IS_VISUALLY_CHECKING = False
242 |     ###################################################
243 |     """ visually checking """
244 |     ###################################################
245 |     if IS_VISUALLY_CHECKING:
246 |         sys.exit()
247 | 
248 |     n = 24
249 |     (x_train, y_train), (x_test, y_test) = d.get_data(batch_size=n)
250 |     sess = tf.Session()
251 |     _x,_y = sess.run([x_train, y_train])
252 |     _x = _x*255
253 |     from PIL import Image
254 |     for i in range(n):
255 |         x,y = _x[i].astype(np.uint8), _y[i]
256 |         img = Image.fromarray(x)
257 |         img.save('outfile_%s.jpg'%(i))
258 |         print('Label is:', i, y)
--------------------------------------------------------------------------------
/glow/util/HandleSVHN.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import os, sys
4 | import os.path
5 | 
6 | import scipy.io
7 | import scipy.io.wavfile
8 | from imageio import imread # for scipy > 1.3.0
9 | import tensorflow as tf
10 | 
11 | from tqdm import tqdm
12 | import numpy as np
13 | 
14 | from utils import numpy_array_to_dataset
15 | from svhn import load_svhn, DATA_DIR, NUM_EXAMPLES_TRAIN, NUM_EXAMPLES_TEST
16 | 
17 | tf.flags.DEFINE_bool("do_write_tfrecord", False, "if True, write tfrecords.")
18 | tf.flags.DEFINE_bool("do_write_npy", True, "if True, write npy.")
19 | 
20 | FLAGS = tf.flags.FLAGS
21 | 
22 | class HandleSVHNData(object):
23 | 
24 |     def __init__(self):
25 |         self.n_train = NUM_EXAMPLES_TRAIN
26 |         self.n_valid = 0
27 |         self.n_test = NUM_EXAMPLES_TEST
28 | 
29 |     def get_ndarray(self):
30 | 
31 |         (x_train, y_train), (x_test, y_test) = load_svhn(True)
32 |         # y is integer at this moment
33 | 
34 |         #y_train = np.identity(10)[y_train]
35 |         #y_test = np.identity(10)[y_test ]
36 | 
37 |         print(x_train[:2])
38 |         print(x_train.shape)
39 |         print(y_train.shape)
40 |         print(x_test.shape)
41 |         print(y_test.shape)
42 |         return (x_train, y_train), (x_test, y_test)
43 |         #sys.exit
44 | 
45 |     def get_data(self,
46 |             image_size: int = 32,
47 |             batch_size: int = 16,
48 |             #batch_size: int = 4,
49 |             buffer_size: int = 512,
50 |             num_parallel_batches: int = 16, # org
51 |             #num_parallel_batches: int = 1,
52 |             ):
53 | 
54 | 
55 |         def get_dataset_from_ndarray():
56 |             (x_train, y_train), (x_test, y_test) = self.get_ndarray()
57 | 
58 |             def _zip(x,y):
59 |                 """ x,y are np.array """
60 |                 x_dataset = numpy_array_to_dataset(x,
61 |                     buffer_size=buffer_size, batch_size=batch_size, num_parallel_batches=num_parallel_batches)
62 | 
63 |                 y_dataset = numpy_array_to_dataset(y,
64 |                     buffer_size=buffer_size, batch_size=batch_size, num_parallel_batches=num_parallel_batches)
65 | 
66 |                 return tf.data.Dataset.zip((x_dataset, y_dataset))
67 | 
68 |             return _zip(x_train, y_train), _zip(x_test, y_test)
69 | 
70 |         def get_dataset_from_tfrecord():
71 |             DATASET_SEED = 1
72 |             dir_root = DATA_DIR + '/seed' + str(DATASET_SEED)
73 |             train_tfrecord = dir_root + '/' + 'labeled_train.tfrecords'
74 |             test_tfrecord = dir_root + '/' + 'test.tfrecords'
75 |             print('... set up to read from', dir_root)
76 |             return tf.data.TFRecordDataset(train_tfrecord), tf.data.TFRecordDataset(test_tfrecord)
77 | 
78 |         def apply_parser(dataset):
79 | 
80 |             def preprocess(example):
81 |                 features = tf.parse_single_example(
82 |                     example,
83 |                     features={
84 |                         "image" : tf.FixedLenFeature([32 * 32 * 3], tf.float32),
85 |                         "label": tf.FixedLenFeature([], tf.int64)
86 |                     }
87 |                 )
88 | 
89 |                 image = tf.reshape(features["image"], [32, 32, 3])
90 |                 label = features["label"]
91 | 
92 |                 return image, label
93 | 
94 |             dataset = dataset.apply(
95 |                 tf.contrib.data.map_and_batch(
96 |                     map_func=preprocess,
97 |                     batch_size=batch_size,
98 |                     num_parallel_batches=num_parallel_batches,
99 |                     drop_remainder=True,
100 |                 )
101 |             )
102 |             return dataset
103 | 
104 |         def read(dataset):
105 | 
106 | 
107 |             if buffer_size > 0:
108 |                 dataset = dataset.apply(
109 |                     tf.contrib.data.shuffle_and_repeat(buffer_size=buffer_size, count=-1)
110 |                 )
111 | 
112 |             if IS_FROM_TFRECORDS:
113 |                 dataset = apply_parser(dataset)
114 | 
115 | 
116 |             dataset = dataset.prefetch(4)
117 |             images, label = dataset.make_one_shot_iterator().get_next()
118 | 
119 |             x = tf.reshape(images, [batch_size, 32, 32, 3])
120 |             #x = tf.image.resize_images(
121 |             #    x, [image_size, image_size], method=0, align_corners=False
122 |             #)
123 |             y = tf.one_hot(tf.cast( label, tf.int32), 10)
124 | 
125 |             return x, y
126 | 
127 |         IS_FROM_TFRECORDS = True
128 |         if IS_FROM_TFRECORDS:
129 |             ds_train, ds_test = get_dataset_from_tfrecord()
130 |         else:
131 |             ds_train, ds_test = get_dataset_from_ndarray()
132 | 
133 |         return read(ds_train), read(ds_test)
134 | 
135 |     def prepare(self):
136 | 
137 |         df_train, df_valid, df_test = get_train_val_test()
138 | 
139 |         if FLAGS.do_write_tfrecord:
140 |             self.write_tfrecord(df_train, TFRECORD_TRAIN)
141 |             self.write_tfrecord(df_valid, TFRECORD_VALID)
142 |             self.write_tfrecord(df_test, TFRECORD_TEST)
143 | 
144 |         if FLAGS.do_write_npy:
145 |             self.write_npy(df_train, NPY_TRAIN)
146 |             self.write_npy(df_valid, NPY_VALID)
147 |             self.write_npy(df_test, NPY_TEST)
148 |         print('... exit prepare()')
149 |         return
150 | 
151 |     def read_image( self, file_name):
152 | 
153 |         """ file_name: str.
jpg image file name""" 154 | 155 | #image = imread(os.path.join(FLAGS.fn_root, file_name)) 156 | image = imread(os.path.join(PATH_TO_RAWIMG, file_name)) 157 | # 158 | # cropping is performed here. 159 | # this is the same way as RealNVP and kmkolasinski's Glow implemention 160 | # 161 | 162 | # original shape [218, 178, 3] is going to be converted into [144, 144, 3] once, 163 | image = image[40:188, 15:163, :] 164 | image = image.reshape([148, 148, 3]) 165 | 166 | # then convert it into (IMG_SIZE, IMG_SIZE) 167 | import cv2 168 | image = cv2.resize(image, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR) 169 | 170 | return image 171 | 172 | def write_npy(self, df, file_name): 173 | 174 | print('... writing npy file.') 175 | x, y = [],[] 176 | for i, row in enumerate(tqdm(df.itertuples(name=None), leave=False, total=len(df))): 177 | 178 | row = list(row) # tuple to list 179 | img_file_name = row.pop(0) 180 | 181 | image = self.read_image(img_file_name) 182 | x.append(image) 183 | y.append(row) 184 | 185 | x = np.array(x) 186 | y = np.array(y) 187 | print(os.path.join(FLAGS.fn_root, 'x_' + file_name)) 188 | print(os.path.join(FLAGS.fn_root, 'y_' + file_name)) 189 | np.save(os.path.join(FLAGS.fn_root, 'x_' + file_name) , x) 190 | np.save(os.path.join(FLAGS.fn_root, 'y_' + file_name) , y) 191 | return 192 | 193 | def write_tfrecord(self, df, file_name): 194 | 195 | def make_example( image, attributes): 196 | """ image: ndarray whose shape is (144, 144, 3)) 197 | attributes: list of 40 features 198 | """ 199 | 200 | image = image.tostring() 201 | example = tf.train.Example( 202 | features=tf.train.Features( 203 | feature={ 204 | "image" : tf.train.Feature(float_list=tf.train.FloatList(value=image)), 205 | "attributes": tf.train.Feature(int64_list=tf.train.Int64List(value=attributes)) 206 | } 207 | ) 208 | ) 209 | return example 210 | 211 | print('... writing', file_name ) 212 | writer = tf.io.TFRecordWriter(file_name) 213 | 214 | for i, row in enumerate(tqdm(df.itertuples(name=None), leave=False, total=len(df))): 215 | 216 | row = list(row) # tuple to list 217 | img_file_name = row.pop(0) 218 | 219 | image = self.read_image(img_file_name) 220 | ex = make_example(image, row) 221 | writer.write(ex.SerializeToString()) 222 | 223 | #if i > 25: break 224 | writer.close() 225 | #sys.exit() 226 | 227 | 228 | if __name__ == "__main__": 229 | 230 | 231 | d = HandleSVHNData() 232 | 233 | #if FLAGS.is_write_mode: 234 | #d.prepare() 235 | 236 | #print(d.get_ndarray()) 237 | print(d.get_data()) 238 | 239 | IS_VISUALLY_CHECKING = False 240 | ################################################### 241 | """ visually checking """ 242 | ################################################### 243 | if IS_VISUALLY_CHECKING: 244 | sys.exit() 245 | 246 | n = 24 247 | (x_train, y_train), (x_test, y_test) = d.get_data(batch_size=n) 248 | sess = tf.Session() 249 | _x,_y = sess.run([x_train, y_train]) 250 | _x = _x*255 251 | from PIL import Image 252 | for i in range(n): 253 | x,y = _x[i].astype(np.uint8), _y[i] 254 | img = Image.fromarray(x) 255 | img.save('outfile_%s.jpg'%(i)) 256 | print('Label is:', i, y) 257 | -------------------------------------------------------------------------------- /glow/util/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | """Routine for decoding the CIFAR-10 binary file format.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | import numpy as np 25 | from scipy import linalg 26 | import glob 27 | import pickle 28 | 29 | from six.moves import xrange # pylint: disable=redefined-builtin 30 | from six.moves import urllib 31 | 32 | import tensorflow as tf 33 | 34 | from dataset_utils import * 35 | 36 | DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 37 | 38 | #tf.app.flags.DEFINE_string('data_dir', 'C:/Users/fx29351/Python/data/CIFAR10', 'where to store the dataset') 39 | #tf.app.flags.DEFINE_integer('num_labeled_examples', 4000, "The number of labeled examples") 40 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 41 | #tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 42 | FLAGS = tf.app.flags.FLAGS 43 | 44 | # Process images of this size. Note that this differs from the original CIFAR 45 | # image size of 32 x 32. If one alters this number, then the entire model 46 | # architecture will change and any model would need to be retrained. 47 | IMAGE_SIZE = 32 48 | 49 | N_LABELED = 4000 50 | DATASET_SEED = 1 51 | DATA_DIR = 'D:/data/img/CIFAR10' 52 | 53 | # Global constants describing the CIFAR-10 data set. 54 | NUM_CLASSES = 10 55 | NUM_EXAMPLES_TRAIN = 50000 56 | NUM_EXAMPLES_TEST = 10000 57 | 58 | def load_cifar10(): 59 | """Download and extract the tarball from Alex's website.""" 60 | dest_directory = DATA_DIR 61 | if not os.path.exists(dest_directory): 62 | os.makedirs(dest_directory) 63 | filename = DATA_URL.split('/')[-1] 64 | filepath = os.path.join(dest_directory, filename) 65 | if not os.path.exists(filepath): 66 | def _progress(count, block_size, total_size): 67 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 68 | (filename, float(count * block_size) / 69 | float(total_size) * 100.0)) 70 | sys.stdout.flush() 71 | 72 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 73 | print() 74 | statinfo = os.stat(filepath) 75 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 76 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 77 | 78 | # Training set 79 | print("Loading training data...") 80 | train_images = np.zeros((NUM_EXAMPLES_TRAIN, 3 * 32 * 32), dtype=np.float32) 81 | train_labels = [] 82 | for i, data_fn in enumerate( 83 | sorted(glob.glob(DATA_DIR + '/cifar-10-batches-py/data_batch*'))): 84 | batch = unpickle(data_fn) 85 | train_images[i * 10000:(i + 1) * 10000] = batch['data'] 86 | train_labels.extend(batch['labels']) 87 | 88 | # geosada 170713 for generative model 89 | #train_images = (train_images - 127.5) / 255. 90 | # -> [0,1] 91 | train_images = train_images / 255. 
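    # (added comment) images are scaled to [0, 1] rather than mean-centered so
    # they can be fed directly to the generative models; the commented-out line
    # above is the older (x - 127.5) / 255 variant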
92 | train_labels = np.asarray(train_labels, dtype=np.int64) 93 | 94 | rand_ix = np.random.permutation(NUM_EXAMPLES_TRAIN) 95 | train_images = train_images[rand_ix] 96 | train_labels = train_labels[rand_ix] 97 | 98 | print("Loading test data...") 99 | test = unpickle(DATA_DIR + '/cifar-10-batches-py/test_batch') 100 | test_images = test['data'].astype(np.float32) 101 | # geosada 170713 102 | #test_images = (test_images - 127.5) / 255. 103 | # -> [0,1] 104 | test_images = test_images / 255. 105 | test_labels = np.asarray(test['labels'], dtype=np.int64) 106 | 107 | # geosada 170713 for generative model 108 | """ 109 | print("Apply ZCA whitening") 110 | components, mean, train_images = ZCA(train_images) 111 | np.save('{}/components'.format(DATA_DIR), components) 112 | np.save('{}/mean'.format(DATA_DIR), mean) 113 | test_images = np.dot(test_images - mean, components.T) 114 | """ 115 | 116 | train_images = train_images.reshape( 117 | (NUM_EXAMPLES_TRAIN, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TRAIN, -1)) 118 | test_images = test_images.reshape( 119 | (NUM_EXAMPLES_TEST, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TEST, -1)) 120 | return (train_images, train_labels), (test_images, test_labels) 121 | 122 | 123 | def prepare_dataset(): 124 | (train_images, train_labels), (test_images, test_labels) = load_cifar10() 125 | 126 | dirpath = os.path.join(DATA_DIR, 'seed' + str(DATASET_SEED)) 127 | if not os.path.exists(dirpath): 128 | os.makedirs(dirpath) 129 | 130 | rng = np.random.RandomState(DATASET_SEED) 131 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 132 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 133 | 134 | examples_per_class = int(N_LABELED / 10) 135 | labeled_train_images = np.zeros((N_LABELED, 3072), dtype=np.float32) 136 | labeled_train_labels = np.zeros((N_LABELED), dtype=np.int64) 137 | for i in xrange(10): 138 | ind = np.where(_train_labels == i)[0] 139 | labeled_train_images[i * examples_per_class:(i + 1) * examples_per_class] \ 140 | = _train_images[ind[0:examples_per_class]] 141 | labeled_train_labels[i * examples_per_class:(i + 1) * examples_per_class] \ 142 | = _train_labels[ind[0:examples_per_class]] 143 | _train_images = np.delete(_train_images, 144 | ind[0:examples_per_class], 0) 145 | _train_labels = np.delete(_train_labels, 146 | ind[0:examples_per_class]) 147 | 148 | rand_ix_labeled = rng.permutation(N_LABELED) 149 | labeled_train_images, labeled_train_labels = \ 150 | labeled_train_images[rand_ix_labeled], labeled_train_labels[rand_ix_labeled] 151 | 152 | convert_images_and_labels(labeled_train_images, 153 | labeled_train_labels, 154 | os.path.join(dirpath, 'labeled_train.tfrecords')) 155 | convert_images_and_labels(train_images, train_labels, 156 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 157 | convert_images_and_labels(test_images, 158 | test_labels, 159 | os.path.join(dirpath, 'test.tfrecords')) 160 | 161 | # Construct dataset for validation 162 | train_images_valid, train_labels_valid = \ 163 | labeled_train_images[FLAGS.num_valid_examples:], labeled_train_labels[FLAGS.num_valid_examples:] 164 | test_images_valid, test_labels_valid = \ 165 | labeled_train_images[:FLAGS.num_valid_examples], labeled_train_labels[:FLAGS.num_valid_examples] 166 | unlabeled_train_images_valid = np.concatenate( 167 | (train_images_valid, _train_images), axis=0) 168 | unlabeled_train_labels_valid = np.concatenate( 169 | (train_labels_valid, _train_labels), axis=0) 170 | convert_images_and_labels(train_images_valid, 
171 | train_labels_valid, 172 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 173 | convert_images_and_labels(unlabeled_train_images_valid, 174 | unlabeled_train_labels_valid, 175 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 176 | convert_images_and_labels(test_images_valid, 177 | test_labels_valid, 178 | os.path.join(dirpath, 'test_val.tfrecords')) 179 | 180 | 181 | def inputs(batch_size=100, 182 | train=True, validation=False, 183 | shuffle=True, num_epochs=None): 184 | if validation: 185 | if train: 186 | filenames = ['labeled_train_val.tfrecords'] 187 | num_examples = N_LABELED - FLAGS.num_valid_examples 188 | else: 189 | filenames = ['test_val.tfrecords'] 190 | num_examples = FLAGS.num_valid_examples 191 | else: 192 | if train: 193 | filenames = ['labeled_train.tfrecords'] 194 | num_examples = N_LABELED 195 | else: 196 | filenames = ['test.tfrecords'] 197 | num_examples = NUM_EXAMPLES_TEST 198 | 199 | # geosada 170701 200 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 201 | #filenames = ['C:/Users/fx29351/Python/data/CIFAR10/seed' + str(DATASET_SEED) + '/' + filename for filename in filenames] 202 | 203 | filename_queue = generate_filename_queue(filenames, DATA_DIR, num_epochs) 204 | image, label = read(filename_queue) 205 | image = transform(tf.cast(image, tf.float32)) if train else image 206 | return generate_batch([image, label], num_examples, batch_size, shuffle) 207 | 208 | 209 | def unlabeled_inputs(batch_size=100, 210 | validation=False, 211 | shuffle=True): 212 | if validation: 213 | filenames = ['unlabeled_train_val.tfrecords'] 214 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 215 | else: 216 | filenames = ['unlabeled_train.tfrecords'] 217 | num_examples = NUM_EXAMPLES_TRAIN 218 | 219 | # geosada 170701 220 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 221 | #filenames = ['C:/Users/fx29351/Python/data/CIFAR10/seed' + str(DATASET_SEED) + '/' + filename for filename in filenames] 222 | filename_queue = generate_filename_queue(filenames, DATA_DIR) 223 | image, label = read(filename_queue) 224 | image = transform(tf.cast(image, tf.float32)) 225 | return generate_batch([image], num_examples, batch_size, shuffle) 226 | 227 | 228 | def main(argv): 229 | prepare_dataset() 230 | 231 | 232 | if __name__ == "__main__": 233 | tf.app.run() 234 | -------------------------------------------------------------------------------- /glow/util/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os, sys, pickle 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | #FLAGS = tf.app.flags.FLAGS 7 | #tf.app.flags.DEFINE_bool('aug_trans', False, "") 8 | #tf.app.flags.DEFINE_bool('aug_flip', False, "") 9 | 10 | AUG_TRANS = False 11 | AUG_FLIP = False 12 | 13 | def unpickle(file): 14 | fp = open(file, 'rb') 15 | if sys.version_info.major == 2: 16 | data = pickle.load(fp) 17 | elif sys.version_info.major == 3: 18 | data = pickle.load(fp, encoding='latin-1') 19 | fp.close() 20 | return data 21 | 22 | 23 | def ZCA(data, reg=1e-6): 24 | mean = np.mean(data, axis=0) 25 | mdata = data - mean 26 | sigma = np.dot(mdata.T, mdata) / mdata.shape[0] 27 | U, S, V = linalg.svd(sigma) 28 | components = np.dot(np.dot(U, np.diag(1 / np.sqrt(S) + reg)), U.T) 29 | whiten = np.dot(data - mean, components.T) 30 | return components, mean, whiten 31 | 32 | 33 | def _int64_feature(value): 34 | return 
tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 35 | 36 | 37 | def _bytes_feature(value): 38 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 39 | 40 | 41 | def convert_images_and_labels(images, labels, filepath): 42 | 43 | print('[DEBUG] inputs shape:', images.shape, labels.shape) # (4000, 3072) (4000,) 44 | num_examples = labels.shape[0] 45 | if images.shape[0] != num_examples: 46 | raise ValueError("Images size %d does not match label size %d." % 47 | (images.shape[0], num_examples)) 48 | print('Writing', filepath) 49 | writer = tf.python_io.TFRecordWriter(filepath) 50 | for index in range(num_examples): 51 | image = images[index].tolist() 52 | image_feature = tf.train.Feature(float_list=tf.train.FloatList(value=image)) 53 | #print('[DEBUG] image_feature:', image_feature) # float_list { value: xxx},...} 54 | example = tf.train.Example(features=tf.train.Features(feature={ 55 | 'height': _int64_feature(32), 56 | 'width': _int64_feature(32), 57 | 'depth': _int64_feature(3), 58 | 'label': _int64_feature(int(labels[index])), 59 | 'image': image_feature})) 60 | writer.write(example.SerializeToString()) 61 | writer.close() 62 | 63 | 64 | def read(filename_queue): 65 | reader = tf.TFRecordReader() 66 | print('filename_queue',filename_queue) 67 | _, serialized_example = reader.read(filename_queue) 68 | features = tf.parse_single_example( 69 | serialized_example, 70 | # Defaults are not specified since both keys are required. 71 | features={ 72 | 'image': tf.FixedLenFeature([3072], tf.float32), 73 | 'label': tf.FixedLenFeature([], tf.int64), 74 | }) 75 | 76 | # Convert label from a scalar uint8 tensor to an int32 scalar. 77 | image = features['image'] 78 | image = tf.reshape(image, [32, 32, 3]) 79 | label = tf.one_hot(tf.cast(features['label'], tf.int32), 10) 80 | return image, label 81 | 82 | 83 | def generate_batch( 84 | example, 85 | min_queue_examples, 86 | batch_size, shuffle): 87 | """ 88 | Arg: 89 | list of tensors. 
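        min_queue_examples: lower bound kept in the queue (min_after_dequeue)
            when shuffling; the queue capacity is min_queue_examples + 5 * batch_size.
        shuffle: if True, batches via tf.train.shuffle_batch, otherwise via
            tf.train.batch with allow_smaller_final_batch=True.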
90 | """ 91 | num_preprocess_threads = 1 92 | 93 | if shuffle: 94 | ret = tf.train.shuffle_batch( 95 | example, 96 | batch_size=batch_size, 97 | num_threads=num_preprocess_threads, 98 | capacity=min_queue_examples + 5 * batch_size, 99 | min_after_dequeue=min_queue_examples) 100 | else: 101 | ret = tf.train.batch( 102 | example, 103 | batch_size=batch_size, 104 | num_threads=num_preprocess_threads, 105 | allow_smaller_final_batch=True, 106 | capacity=min_queue_examples + 5 * batch_size) 107 | 108 | return ret 109 | 110 | 111 | def transform(image): 112 | image = tf.reshape(image, [32, 32, 3]) 113 | if AUG_TRANS or AUG_FLIP: 114 | print("augmentation") 115 | if AUG_TRANS: 116 | image = tf.pad(image, [[2, 2], [2, 2], [0, 0]]) 117 | image = tf.random_crop(image, [32, 32, 3]) 118 | if AUG_FLIP: 119 | image = tf.image.random_flip_left_right(image) 120 | return image 121 | 122 | 123 | def generate_filename_queue(filenames, data_dir, num_epochs=None): 124 | print("filenames in queue:", filenames) 125 | for i in range(len(filenames)): 126 | filenames[i] = os.path.join(data_dir, filenames[i]) 127 | return tf.train.string_input_producer(filenames, num_epochs=num_epochs) 128 | 129 | 130 | -------------------------------------------------------------------------------- /glow/util/svhn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | from scipy.io import loadmat 8 | 9 | import numpy as np 10 | from scipy import linalg 11 | import glob 12 | import pickle 13 | 14 | from six.moves import xrange # pylint: disable=redefined-builtin 15 | from six.moves import urllib 16 | 17 | import tensorflow as tf 18 | from dataset_utils import * 19 | 20 | DATA_URL_TRAIN = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat' 21 | DATA_URL_TEST = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat' 22 | 23 | N_LABELED = 73257 24 | #N_LABELED = 1000 25 | DATASET_SEED = 1 26 | #DATA_DIR = '/data/img/SVHN__labled_1000' 27 | DATA_DIR = '/data/img/SVHN' 28 | DATA_DIR = 'D:/data/img/SVHN' 29 | 30 | FLAGS = tf.app.flags.FLAGS 31 | #tf.app.flags.DEFINE_string('data_dir', '/tmp/svhn', "") 32 | #tf.app.flags.DEFINE_integer('num_labeled_examples', 1000, "The number of labeled examples") 33 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 34 | #tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 35 | 36 | NUM_EXAMPLES_TRAIN = 73257 37 | NUM_EXAMPLES_TEST = 26032 38 | 39 | 40 | def maybe_download_and_extract(): 41 | if not os.path.exists(DATA_DIR): 42 | os.makedirs(DATA_DIR) 43 | filepath_train_mat = os.path.join(DATA_DIR, 'train_32x32.mat') 44 | filepath_test_mat = os.path.join(DATA_DIR, 'test_32x32.mat') 45 | print(filepath_train_mat) 46 | print(filepath_test_mat) 47 | if not os.path.exists(filepath_train_mat) or not os.path.exists(filepath_test_mat): 48 | def _progress(count, block_size, total_size): 49 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 50 | sys.stdout.flush() 51 | 52 | urllib.request.urlretrieve(DATA_URL_TRAIN, filepath_train_mat, _progress) 53 | urllib.request.urlretrieve(DATA_URL_TEST, filepath_test_mat, _progress) 54 | 55 | # Training set 56 | print("... loading training data") 57 | train_data = loadmat(DATA_DIR + '/train_32x32.mat') 58 | 59 | # geosada 170717 60 | #train_x = (-127.5 + train_data['X']) / 255. 
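    # -> [0,1]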
61 |     train_x = (train_data['X']) / 255.
62 |     train_x = train_x.transpose((3, 0, 1, 2))
63 |     train_x = train_x.reshape([train_x.shape[0], -1])
64 |     train_y = train_data['y'].flatten().astype(np.int32)
65 |     train_y[train_y == 10] = 0
66 | 
67 |     # Test set
68 |     print("... loading test data")
69 |     test_data = loadmat(DATA_DIR + '/test_32x32.mat')
70 |     # geosada 170717
71 |     #test_x = (-127.5 + test_data['X']) / 255.
72 |     test_x = (test_data['X']) / 255.
73 |     test_x = test_x.transpose((3, 0, 1, 2))
74 |     test_x = test_x.reshape((test_x.shape[0], -1))
75 |     test_y = test_data['y'].flatten().astype(np.int32)
76 |     test_y[test_y == 10] = 0
77 | 
78 |     print("... saving npy as cache")
79 |     np.save('{}/train_images'.format(DATA_DIR), train_x)
80 |     np.save('{}/train_labels'.format(DATA_DIR), train_y)
81 |     np.save('{}/test_images'.format(DATA_DIR), test_x)
82 |     np.save('{}/test_labels'.format(DATA_DIR), test_y)
83 | 
84 | 
85 | def load_svhn(_use_cache=False):
86 | 
87 |     if _use_cache:
88 |         print("... loading data from npy cache")
89 |     else:
90 |         maybe_download_and_extract()
91 | 
92 |     #
93 |     # returned shape:
94 |     #   images: (n, img_size)
95 |     #   labels: (n,)
96 |     #
97 |     train_images = np.load('{}/train_images.npy'.format(DATA_DIR)).astype(np.float32)
98 |     train_labels = np.load('{}/train_labels.npy'.format(DATA_DIR)).astype(np.uint8)
99 |     test_images = np.load('{}/test_images.npy'.format(DATA_DIR)).astype(np.float32)
100 |     test_labels = np.load('{}/test_labels.npy'.format(DATA_DIR)).astype(np.uint8)
101 |     return (train_images, train_labels), (test_images, test_labels)
102 | 
103 | 
104 | def prepare_dataset():
105 |     (train_images, train_labels), (test_images, test_labels) = load_svhn()
106 |     dirpath = os.path.join(DATA_DIR, 'seed' + str(DATASET_SEED))
107 |     if not os.path.exists(dirpath):
108 |         os.makedirs(dirpath)
109 | 
110 |     rng = np.random.RandomState(DATASET_SEED)
111 |     rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN)
112 |     _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix]
113 | 
114 | 
115 |     if N_LABELED == NUM_EXAMPLES_TRAIN:
116 |         print('>>> supervised <<<')
117 |         labeled_ind = np.arange(N_LABELED)
118 |         labeled_train_images, labeled_train_labels = _train_images[labeled_ind], _train_labels[labeled_ind]
119 |         _train_images = np.delete(_train_images, labeled_ind, 0)
120 |         _train_labels = np.delete(_train_labels, labeled_ind, 0)
121 | 
122 |     else:
123 |         print('>>> semi-supervised <<<')
124 |         examples_per_class = int(N_LABELED / 10)
125 |         labeled_train_images = np.zeros((N_LABELED, 3072), dtype=np.float32)
126 |         labeled_train_labels = np.zeros((N_LABELED), dtype=np.int64)
127 |         for i in xrange(10):
128 |             ind = np.where(_train_labels == i)[0]
129 |             labeled_train_images[i * examples_per_class:(i + 1) * examples_per_class] \
130 |                 = _train_images[ind[0:examples_per_class]]
131 |             labeled_train_labels[i * examples_per_class:(i + 1) * examples_per_class] \
132 |                 = _train_labels[ind[0:examples_per_class]]
133 |             _train_images = np.delete(_train_images,
134 |                                       ind[0:examples_per_class], 0)
135 |             _train_labels = np.delete(_train_labels,
136 |                                       ind[0:examples_per_class])
137 | 
138 |     rand_ix_labeled = rng.permutation(N_LABELED)
139 |     labeled_train_images, labeled_train_labels = labeled_train_images[rand_ix_labeled], labeled_train_labels[rand_ix_labeled]
140 | 
141 |     #print(labeled_train_images.shape, labeled_train_labels.shape)
142 |     #print(train_images.shape, train_labels.shape)
143 |     #print(test_images.shape, test_labels.shape)
144 |     #print(labeled_train_labels)
145 | 
146 | 
convert_images_and_labels(labeled_train_images, 147 | labeled_train_labels, 148 | os.path.join(dirpath, 'labeled_train.tfrecords')) 149 | convert_images_and_labels(train_images, train_labels, 150 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 151 | convert_images_and_labels(test_images, 152 | test_labels, 153 | os.path.join(dirpath, 'test.tfrecords')) 154 | 155 | # Construct dataset for validation 156 | train_images_valid, train_labels_valid = labeled_train_images, labeled_train_labels 157 | test_images_valid, test_labels_valid = \ 158 | _train_images[:FLAGS.num_valid_examples], _train_labels[:FLAGS.num_valid_examples] 159 | unlabeled_train_images_valid = np.concatenate( 160 | (train_images_valid, _train_images[FLAGS.num_valid_examples:]), axis=0) 161 | unlabeled_train_labels_valid = np.concatenate( 162 | (train_labels_valid, _train_labels[FLAGS.num_valid_examples:]), axis=0) 163 | convert_images_and_labels(train_images_valid, 164 | train_labels_valid, 165 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 166 | convert_images_and_labels(unlabeled_train_images_valid, 167 | unlabeled_train_labels_valid, 168 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 169 | convert_images_and_labels(test_images_valid, 170 | test_labels_valid, 171 | os.path.join(dirpath, 'test_val.tfrecords')) 172 | 173 | 174 | def inputs(batch_size=100, 175 | train=True, validation=False, 176 | shuffle=True, num_epochs=None): 177 | if validation: 178 | if train: 179 | filenames = ['labeled_train_val.tfrecords'] 180 | num_examples = N_LABELED 181 | else: 182 | filenames = ['test_val.tfrecords'] 183 | num_examples = FLAGS.num_valid_examples 184 | else: 185 | if train: 186 | filenames = ['labeled_train.tfrecords'] 187 | num_examples = N_LABELED 188 | else: 189 | filenames = ['test.tfrecords'] 190 | num_examples = NUM_EXAMPLES_TEST 191 | 192 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 193 | filename_queue = generate_filename_queue(filenames, DATA_DIR, num_epochs) 194 | image, label = read(filename_queue) 195 | image = transform(tf.cast(image, tf.float32)) if train else image 196 | return generate_batch([image, label], num_examples, batch_size, shuffle) 197 | 198 | 199 | def unlabeled_inputs(batch_size=100, 200 | validation=False, 201 | shuffle=True): 202 | if validation: 203 | filenames = ['unlabeled_train_val.tfrecords'] 204 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 205 | else: 206 | filenames = ['unlabeled_train.tfrecords'] 207 | num_examples = NUM_EXAMPLES_TRAIN 208 | 209 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 210 | filename_queue = generate_filename_queue(filenames, data_dir=DATA_DIR) 211 | image, label = read(filename_queue) 212 | image = transform(tf.cast(image, tf.float32)) 213 | return generate_batch([image], num_examples, batch_size, shuffle) 214 | 215 | 216 | def main(argv): 217 | prepare_dataset() 218 | 219 | 220 | if __name__ == "__main__": 221 | tf.app.run() 222 | -------------------------------------------------------------------------------- /glow/util/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used in notebooks""" 2 | from collections import defaultdict 3 | from typing import Optional, Callable, List 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import tensorflow as tf 8 | from tqdm import tqdm 9 | 10 | 11 | def numpy_array_to_dataset( 12 | array: np.array, 13 | 
buffer_size: int = 512, 14 | batch_size: int = 100, 15 | num_parallel_batches: int = 16, 16 | preprocess_fn: Optional[Callable] = None, 17 | ) -> tf.data.Dataset: 18 | """Convert numpy array to tf.data.Dataset""" 19 | dataset = tf.data.Dataset.from_tensor_slices(array.astype(np.float32)) 20 | dataset = dataset.apply( 21 | tf.contrib.data.shuffle_and_repeat(buffer_size=buffer_size, count=-1) 22 | ) 23 | if preprocess_fn is None: 24 | dataset = dataset.batch(batch_size=batch_size) 25 | else: 26 | dataset = dataset.apply( 27 | tf.contrib.data.map_and_batch( 28 | map_func=preprocess_fn, 29 | batch_size=batch_size, 30 | num_parallel_batches=num_parallel_batches, 31 | drop_remainder=True, 32 | ) 33 | ) 34 | dataset = dataset.prefetch(4) 35 | return dataset 36 | 37 | 38 | def create_tfrecord_dataset_iterator( 39 | tfrecord_paths: List[str], 40 | image_size: int = 48, 41 | #batch_size: int = 16, 42 | batch_size: int = 4, 43 | buffer_size: int = 512, 44 | #num_parallel_batches: int = 16, # org 45 | num_parallel_batches: int = 1, 46 | ) -> tf.Tensor: 47 | """ 48 | Celeba dataset loader 49 | Args: 50 | tfrecord_paths: path to the tfrecords with Celeba images 51 | image_size: a resize size of the input images 52 | batch_size: batch size 53 | buffer_size: shuffle buffer size 54 | num_parallel_batches: number of parallel calls when reading 55 | and preparing dataset 56 | 57 | Returns: 58 | image a tensor iterator of shape [batch_size, image_size, image_size, 3] 59 | """ 60 | dataset = tf.data.TFRecordDataset(tfrecord_paths) 61 | 62 | def preprocess_images(example): 63 | features = tf.parse_single_example( 64 | example, features={"image_raw": tf.FixedLenFeature([], tf.string)} 65 | ) 66 | 67 | image = tf.decode_raw(features["image_raw"], tf.uint8) 68 | image.set_shape([218 * 178 * 3]) # 218, 178 69 | image = tf.cast(image, tf.float32) 70 | image = tf.reshape(image, [218, 178, 3]) 71 | image = image[40:188, 15:163, :] 72 | image = tf.reshape(image, [148, 148, 3]) 73 | return image 74 | 75 | if buffer_size > 0: 76 | dataset = dataset.apply( 77 | tf.contrib.data.shuffle_and_repeat(buffer_size=buffer_size, 78 | count=-1) 79 | ) 80 | 81 | dataset = dataset.apply( 82 | tf.contrib.data.map_and_batch( 83 | map_func=preprocess_images, 84 | batch_size=batch_size, 85 | num_parallel_batches=num_parallel_batches, 86 | drop_remainder=True, 87 | ) 88 | ) 89 | dataset = dataset.prefetch(4) 90 | images = dataset.make_one_shot_iterator().get_next() 91 | 92 | x_in = tf.reshape(images, [batch_size, 148, 148, 3]) 93 | x_in = tf.image.resize_images( 94 | x_in, [image_size, image_size], method=0, align_corners=False 95 | ) 96 | return x_in / 255.0 97 | 98 | 99 | _epsilon = 1e-5 100 | 101 | 102 | def safe_log(x: tf.Tensor) -> tf.Tensor: 103 | return tf.log(tf.maximum(x, _epsilon)) 104 | 105 | 106 | class Metrics: 107 | def __init__(self, step, metrics_tensors): 108 | self.metrics = defaultdict(list) 109 | self.step = step 110 | self.metrics_tensors = metrics_tensors 111 | 112 | def check_step(self, i): 113 | return (i + 1) % self.step == 0 114 | 115 | def append(self, results): 116 | for k, t in self.metrics_tensors.items(): 117 | self.metrics[k].append(results[k]) 118 | print(k, results[k]) 119 | 120 | def get(self): 121 | return self.metrics_tensors 122 | 123 | @property 124 | def num_metrics(self): 125 | return len(self.metrics) 126 | 127 | 128 | class PlotMetricsHook: 129 | def __init__(self, metrics: Metrics, step=1000, figsize=(15, 3), 130 | skip_steps=5): 131 | self.metrics = metrics 132 | self.step = step 133 | 
self.figsize = figsize
134 |         self.skip_steps = skip_steps
135 | 
136 |     def check_step(self, i):
137 |         return (i + 1) % self.step == 0
138 | 
139 |     def run(self):
140 |         plt.figure(figsize=self.figsize)
141 | 
142 |         for k, (m, values) in enumerate(self.metrics.metrics.items()):
143 |             plt.subplot(1, self.metrics.num_metrics, k + 1)
144 |             plt.title(m)
145 |             vals = values[self.skip_steps:]
146 |             plt.plot(vals)
147 |             vals = np.array(vals)
148 |             if len(vals) > 0:
149 |                 plt.ylim([vals.min(), vals.max()])
150 |         plt.show()
151 | 
152 | 
153 | def trainer(sess, num_steps, train_op, feed_dict_fn, metrics, hooks):
154 |     for i in tqdm(range(num_steps), leave=False):
155 |         fetches = {"train_op": train_op}
156 | 
157 |         for metric in metrics:
158 |             if metric.check_step(i):
159 |                 fetches.update(**metric.get())
160 | 
161 |         results = sess.run(fetches=fetches, feed_dict=feed_dict_fn())
162 | 
163 |         for metric in metrics:
164 |             if metric.check_step(i):
165 |                 metric.append(results)
166 | 
167 |         continue  # skips the plotting hooks below (left disabled); remove this line to re-enable them
168 |         for hook in hooks:
169 |             if hook.check_step(i):
170 |                 hook.run()
171 | 
172 | 
173 | def plot_4x4_grid(
174 |     images: np.ndarray,
175 |     shape: tuple = (28, 28),
176 |     cmap="gray",
177 |     figsize=(4, 4),
178 |     filename='figure_4x4.png'
179 | ) -> None:
180 |     """
181 |     Plot multiple images in subplot grid.
182 |     :param images: tensor with MNIST images with shape [16, *shape]
183 |     :param shape: shape of the images
184 |     """
185 |     assert images.shape[0] >= 16
186 |     dist_samples_np = images[:16, ...].reshape([4, 4, *shape])
187 | 
188 |     plt.figure(figsize=figsize)
189 |     for i in range(4):
190 |         for j in range(4):
191 |             plt.subplot(4, 4, i * 4 + j + 1)
192 |             plt.imshow(dist_samples_np[i, j], cmap=cmap)
193 |             plt.xticks([])
194 |             plt.yticks([])
195 |     plt.subplots_adjust(hspace=0.05, wspace=0.05)
196 |     plt.savefig(filename)
197 | 
198 | 
199 | def plot_grid(images: tf.Tensor, filename='figures.png') -> tf.Tensor:
200 |     """
201 |     Plot grid of images using tf.contrib.gan.eval.image_grid
202 |     Args:
203 |         images: a tensor with batch of images of shape
204 |             [batch_size, size, size, 3]
205 | 
206 |     Returns:
207 |         a grid image
208 |     """
209 |     batch_size, image_size = images.shape.as_list()[:2]
210 | 
211 |     grid_image = tf.contrib.gan.eval.image_grid(
212 |         images,
213 |         grid_shape=[4, batch_size // 4],
214 |         image_shape=(image_size, image_size),
215 |         num_channels=3
216 |     )
217 | 
218 |     return grid_image[0]
219 | 
--------------------------------------------------------------------------------
/lvat/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/lvat/cifar10.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | import sys
7 | import tarfile
8 | 
9 | import numpy as np
10 | from scipy import linalg
11 | import glob
12 | import pickle
13 | 
14 | from six.moves import xrange  # pylint: disable=redefined-builtin
15 | from six.moves import urllib
16 | 
17 | import tensorflow as tf
18 | 
19 | from dataset_utils import *
20 | 
21 | DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
22 | 
23 | IMAGE_SIZE = 32
24 | 
25 | NUM_CLASSES = 10
26 | NUM_EXAMPLES_TRAIN = 50000
27 | NUM_EXAMPLES_TEST = 10000
28 | 
29 | def load_cifar10():
30 |     """Download and extract the tarball from Alex's website."""
31 |     dest_directory = FLAGS.data__dir
32 
| if not os.path.exists(dest_directory): 33 | os.makedirs(dest_directory) 34 | filename = DATA_URL.split('/')[-1] 35 | filepath = os.path.join(dest_directory, filename) 36 | if not os.path.exists(filepath): 37 | def _progress(count, block_size, total_size): 38 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 39 | (filename, float(count * block_size) / 40 | float(total_size) * 100.0)) 41 | sys.stdout.flush() 42 | 43 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 44 | print() 45 | statinfo = os.stat(filepath) 46 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 47 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 48 | 49 | # Training set 50 | print("Loading training data...") 51 | train_images = np.zeros((NUM_EXAMPLES_TRAIN, 3 * 32 * 32), dtype=np.float32) 52 | train_labels = [] 53 | for i, data_fn in enumerate( 54 | sorted(glob.glob(FLAGS.data__dir + '/cifar-10-batches-py/data_batch*'))): 55 | batch = unpickle(data_fn) 56 | train_images[i * 10000:(i + 1) * 10000] = batch['data'] 57 | train_labels.extend(batch['labels']) 58 | train_images = (train_images - 127.5) / 255. 59 | train_labels = np.asarray(train_labels, dtype=np.int64) 60 | 61 | rand_ix = np.random.permutation(NUM_EXAMPLES_TRAIN) 62 | train_images = train_images[rand_ix] 63 | train_labels = train_labels[rand_ix] 64 | 65 | print("Loading test data...") 66 | test = unpickle(FLAGS.data__dir + '/cifar-10-batches-py/test_batch') 67 | test_images = test['data'].astype(np.float32) 68 | test_images = (test_images - 127.5) / 255. 69 | test_labels = np.asarray(test['labels'], dtype=np.int64) 70 | 71 | print("Apply ZCA whitening") 72 | components, mean, train_images = ZCA(train_images) 73 | np.save('{}/components'.format(FLAGS.data__dir), components) 74 | np.save('{}/mean'.format(FLAGS.data__dir), mean) 75 | test_images = np.dot(test_images - mean, components.T) 76 | 77 | train_images = train_images.reshape( 78 | (NUM_EXAMPLES_TRAIN, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TRAIN, -1)) 79 | test_images = test_images.reshape( 80 | (NUM_EXAMPLES_TEST, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TEST, -1)) 81 | return (train_images, train_labels), (test_images, test_labels) 82 | 83 | 84 | def prepare_dataset(): 85 | (train_images, train_labels), (test_images, test_labels) = load_cifar10() 86 | dirpath = os.path.join(FLAGS.data__dir, 'seed' + str(FLAGS.dataset_seed)) 87 | if not os.path.exists(dirpath): 88 | os.makedirs(dirpath) 89 | 90 | rng = np.random.RandomState(FLAGS.dataset_seed) 91 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 92 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 93 | 94 | examples_per_class = int(FLAGS.num_labeled_examples / 10) 95 | labeled_train_images = np.zeros((FLAGS.num_labeled_examples, 3072), dtype=np.float32) 96 | labeled_train_labels = np.zeros((FLAGS.num_labeled_examples), dtype=np.int64) 97 | for i in xrange(10): 98 | ind = np.where(_train_labels == i)[0] 99 | labeled_train_images[i * examples_per_class:(i + 1) * examples_per_class] \ 100 | = _train_images[ind[0:examples_per_class]] 101 | labeled_train_labels[i * examples_per_class:(i + 1) * examples_per_class] \ 102 | = _train_labels[ind[0:examples_per_class]] 103 | _train_images = np.delete(_train_images, 104 | ind[0:examples_per_class], 0) 105 | _train_labels = np.delete(_train_labels, 106 | ind[0:examples_per_class]) 107 | 108 | rand_ix_labeled = rng.permutation(FLAGS.num_labeled_examples) 109 | labeled_train_images, 
labeled_train_labels = \ 110 | labeled_train_images[rand_ix_labeled], labeled_train_labels[rand_ix_labeled] 111 | 112 | convert_images_and_labels(labeled_train_images, 113 | labeled_train_labels, 114 | os.path.join(dirpath, 'labeled_train.tfrecords')) 115 | convert_images_and_labels(train_images, train_labels, 116 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 117 | convert_images_and_labels(test_images, 118 | test_labels, 119 | os.path.join(dirpath, 'test.tfrecords')) 120 | 121 | # Construct dataset for validation 122 | train_images_valid, train_labels_valid = \ 123 | labeled_train_images[FLAGS.num_valid_examples:], labeled_train_labels[FLAGS.num_valid_examples:] 124 | test_images_valid, test_labels_valid = \ 125 | labeled_train_images[:FLAGS.num_valid_examples], labeled_train_labels[:FLAGS.num_valid_examples] 126 | unlabeled_train_images_valid = np.concatenate( 127 | (train_images_valid, _train_images), axis=0) 128 | unlabeled_train_labels_valid = np.concatenate( 129 | (train_labels_valid, _train_labels), axis=0) 130 | convert_images_and_labels(train_images_valid, 131 | train_labels_valid, 132 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 133 | convert_images_and_labels(unlabeled_train_images_valid, 134 | unlabeled_train_labels_valid, 135 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 136 | convert_images_and_labels(test_images_valid, 137 | test_labels_valid, 138 | os.path.join(dirpath, 'test_val.tfrecords')) 139 | 140 | 141 | def inputs(batch_size=100, 142 | train=True, validation=False, 143 | shuffle=True, num_epochs=None): 144 | if validation: 145 | if train: 146 | filenames = ['labeled_train_val.tfrecords'] 147 | num_examples = FLAGS.num_labeled_examples - FLAGS.num_valid_examples 148 | else: 149 | filenames = ['test_val.tfrecords'] 150 | num_examples = FLAGS.num_valid_examples 151 | else: 152 | if train: 153 | filenames = ['labeled_train.tfrecords'] 154 | num_examples = FLAGS.num_labeled_examples 155 | else: 156 | filenames = ['test.tfrecords'] 157 | num_examples = NUM_EXAMPLES_TEST 158 | 159 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 160 | 161 | filename_queue = generate_filename_queue(filenames, FLAGS.data__dir, num_epochs) 162 | image, label = read(filename_queue) 163 | image = transform(tf.cast(image, tf.float32)) if train else image 164 | return generate_batch([image, label], num_examples, batch_size, shuffle) 165 | 166 | 167 | def unlabeled_inputs(batch_size=100, 168 | validation=False, 169 | shuffle=True): 170 | if validation: 171 | filenames = ['unlabeled_train_val.tfrecords'] 172 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 173 | else: 174 | filenames = ['unlabeled_train.tfrecords'] 175 | num_examples = NUM_EXAMPLES_TRAIN 176 | 177 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 178 | filename_queue = generate_filename_queue(filenames, FLAGS.data__dir) 179 | image, label = read(filename_queue) 180 | image = transform(tf.cast(image, tf.float32)) 181 | return generate_batch([image], num_examples, batch_size, shuffle) 182 | 183 | 184 | def main(argv): 185 | prepare_dataset() 186 | 187 | 188 | if __name__ == "__main__": 189 | FLAGS = tf.app.flags.FLAGS 190 | #tf.app.flags.DEFINE_string('data__dir', '/data/img/CIFAR10__labled_4000', 191 | #tf.app.flags.DEFINE_string('data__dir', '/data/img/CIFAR10_w_ZCA__labled_4000', 192 | tf.app.flags.DEFINE_string('data__dir', '/data/img/CIFAR10_w_ZCA', 193 | 'where to store the dataset') 194 
| #tf.app.flags.DEFINE_integer('num_labeled_examples', 4000, "The number of labeled examples") 195 | tf.app.flags.DEFINE_integer('num_labeled_examples', 50000, "The number of labeled examples") 196 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 197 | tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 198 | 199 | tf.app.run() 200 | -------------------------------------------------------------------------------- /lvat/cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy 3 | import sys, os 4 | import layers_vat as L 5 | 6 | FLAGS = tf.app.flags.FLAGS 7 | tf.app.flags.DEFINE_float('keep_prob_hidden', 0.5, "dropout rate") 8 | tf.app.flags.DEFINE_float('lrelu_a', 0.1, "lrelu slope") 9 | tf.app.flags.DEFINE_boolean('top_bn', False, "") 10 | 11 | 12 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234): 13 | h = x 14 | 15 | rng = numpy.random.RandomState(seed) 16 | 17 | h = L.conv(h, ksize=3, stride=1, f_in=3, f_out=128, seed=rng.randint(123456), name='c1') 18 | h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b1'), FLAGS.lrelu_a) 19 | h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c2') 20 | h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b2'), FLAGS.lrelu_a) 21 | h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=128, seed=rng.randint(123456), name='c3') 22 | h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b3'), FLAGS.lrelu_a) 23 | 24 | h = L.max_pool(h, ksize=2, stride=2) 25 | h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h 26 | 27 | h = L.conv(h, ksize=3, stride=1, f_in=128, f_out=256, seed=rng.randint(123456), name='c4') 28 | h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b4'), FLAGS.lrelu_a) 29 | h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c5') 30 | h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b5'), FLAGS.lrelu_a) 31 | h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=256, seed=rng.randint(123456), name='c6') 32 | h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b6'), FLAGS.lrelu_a) 33 | 34 | h = L.max_pool(h, ksize=2, stride=2) 35 | h = tf.nn.dropout(h, keep_prob=FLAGS.keep_prob_hidden, seed=rng.randint(123456)) if stochastic else h 36 | 37 | h = L.conv(h, ksize=3, stride=1, f_in=256, f_out=512, seed=rng.randint(123456), padding="VALID", name='c7') 38 | h = L.lrelu(L.bn(h, 512, is_training=is_training, update_batch_stats=update_batch_stats, name='b7'), FLAGS.lrelu_a) 39 | h = L.conv(h, ksize=1, stride=1, f_in=512, f_out=256, seed=rng.randint(123456), name='c8') 40 | h = L.lrelu(L.bn(h, 256, is_training=is_training, update_batch_stats=update_batch_stats, name='b8'), FLAGS.lrelu_a) 41 | h = L.conv(h, ksize=1, stride=1, f_in=256, f_out=128, seed=rng.randint(123456), name='c9') 42 | h = L.lrelu(L.bn(h, 128, is_training=is_training, update_batch_stats=update_batch_stats, name='b9'), FLAGS.lrelu_a) 43 | 44 | h = tf.reduce_mean(h, reduction_indices=[1, 2]) # Global average pooling 45 | h = L.fc(h, 128, FLAGS.n_class, seed=rng.randint(123456), name='fc') 46 | 47 | if FLAGS.top_bn: 48 | h = L.bn(h, FLAGS.n_class, 
is_training=is_training, 49 | update_batch_stats=update_batch_stats, name='bfc') 50 | 51 | return h 52 | -------------------------------------------------------------------------------- /lvat/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os, sys, pickle 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | FLAGS = tf.app.flags.FLAGS 7 | #tf.app.flags.DEFINE_bool('is_aug_trans', False, "") 8 | #tf.app.flags.DEFINE_bool('is_aug_flip', False, "") 9 | 10 | def unpickle(file): 11 | fp = open(file, 'rb') 12 | if sys.version_info.major == 2: 13 | data = pickle.load(fp) 14 | elif sys.version_info.major == 3: 15 | data = pickle.load(fp, encoding='latin-1') 16 | fp.close() 17 | return data 18 | 19 | 20 | def ZCA(data, reg=1e-6): 21 | mean = np.mean(data, axis=0) 22 | mdata = data - mean 23 | sigma = np.dot(mdata.T, mdata) / mdata.shape[0] 24 | U, S, V = linalg.svd(sigma) 25 | components = np.dot(np.dot(U, np.diag(1 / np.sqrt(S) + reg)), U.T) 26 | whiten = np.dot(data - mean, components.T) 27 | return components, mean, whiten 28 | 29 | 30 | def _int64_feature(value): 31 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 32 | 33 | 34 | def _bytes_feature(value): 35 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 36 | 37 | 38 | def convert_images_and_labels(images, labels, filepath): 39 | num_examples = labels.shape[0] 40 | if images.shape[0] != num_examples: 41 | raise ValueError("Images size %d does not match label size %d." % 42 | (images.shape[0], num_examples)) 43 | print('Writing', filepath) 44 | writer = tf.python_io.TFRecordWriter(filepath) 45 | for index in range(num_examples): 46 | image = images[index].tolist() 47 | image_feature = tf.train.Feature(float_list=tf.train.FloatList(value=image)) 48 | example = tf.train.Example(features=tf.train.Features(feature={ 49 | 'height': _int64_feature(32), 50 | 'width': _int64_feature(32), 51 | 'depth': _int64_feature(3), 52 | 'label': _int64_feature(int(labels[index])), 53 | 'image': image_feature})) 54 | writer.write(example.SerializeToString()) 55 | writer.close() 56 | 57 | 58 | def read(filename_queue): 59 | reader = tf.TFRecordReader() 60 | _, serialized_example = reader.read(filename_queue) 61 | features = tf.parse_single_example( 62 | serialized_example, 63 | # Defaults are not specified since both keys are required. 64 | features={ 65 | 'image': tf.FixedLenFeature([3072], tf.float32), 66 | 'label': tf.FixedLenFeature([], tf.int64), 67 | }) 68 | 69 | # Convert label from a scalar uint8 tensor to an int32 scalar. 70 | image = features['image'] 71 | image = tf.reshape(image, [32, 32, 3]) 72 | label = tf.one_hot(tf.cast(features['label'], tf.int32), 10) 73 | return image, label 74 | 75 | 76 | def generate_batch( 77 | example, 78 | min_queue_examples, 79 | batch_size, shuffle): 80 | """ 81 | Arg: 82 | list of tensors. 
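        min_queue_examples is used both as min_after_dequeue for shuffling and
        to size the queue capacity (min_queue_examples + 3 * batch_size).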
83 |     """
84 |     num_preprocess_threads = 1
85 | 
86 |     if shuffle:
87 |         ret = tf.train.shuffle_batch(
88 |             example,
89 |             batch_size=batch_size,
90 |             num_threads=num_preprocess_threads,
91 |             capacity=min_queue_examples + 3 * batch_size,
92 |             min_after_dequeue=min_queue_examples)
93 |     else:
94 |         ret = tf.train.batch(
95 |             example,
96 |             batch_size=batch_size,
97 |             num_threads=num_preprocess_threads,
98 |             allow_smaller_final_batch=True,
99 |             capacity=min_queue_examples + 3 * batch_size)
100 | 
101 |     return ret
102 | 
103 | 
104 | def transform(image):
105 |     image = tf.reshape(image, [32, 32, 3])
106 |     if FLAGS.is_aug_trans or FLAGS.is_aug_flip:
107 |         print("augmentation is enabled")
108 |         if FLAGS.is_aug_trans:
109 |             print('is_aug_trans:', FLAGS.is_aug_trans)
110 |             image = tf.pad(image, [[2, 2], [2, 2], [0, 0]])
111 |             image = tf.random_crop(image, [32, 32, 3])
112 |         if FLAGS.is_aug_flip:
113 |             print('is_aug_flip:', FLAGS.is_aug_flip)
114 |             image = tf.image.random_flip_left_right(image)
115 |     else:
116 |         print("augmentation is disabled")
117 | 
118 |     return image
119 | 
120 | 
121 | def generate_filename_queue(filenames, data_dir, num_epochs=None):
122 |     print("filenames in queue:", filenames)
123 |     for i in range(len(filenames)):
124 |         filenames[i] = os.path.join(data_dir, filenames[i])
125 |     return tf.train.string_input_producer(filenames, num_epochs=num_epochs)
126 | 
127 | 
128 | 
--------------------------------------------------------------------------------
/lvat/layers_vat.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy
3 | import sys, os
4 | 
5 | 
6 | FLAGS = tf.app.flags.FLAGS
7 | tf.app.flags.DEFINE_float('bn_stats_decay_factor', 0.99,
8 |                           "moving average decay factor for stats on batch normalization")
9 | 
10 | 
11 | def lrelu(x, a=0.1):
12 |     if a < 1e-16:
13 |         return tf.nn.relu(x)
14 |     else:
15 |         return tf.maximum(x, a * x)
16 | 
17 | 
18 | def bn(x, dim, is_training=True, update_batch_stats=True, collections=None, name="bn"):
19 |     params_shape = (dim,)
20 |     n = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1]))
21 |     axis = list(range(int(tf.shape(x).get_shape().as_list()[0]) - 1))
22 |     mean = tf.reduce_mean(x, axis)
23 |     var = tf.reduce_mean(tf.pow(x - mean, 2.0), axis)
24 |     avg_mean = tf.get_variable(
25 |         name=name + "_mean",
26 |         shape=params_shape,
27 |         initializer=tf.constant_initializer(0.0),
28 |         collections=collections,
29 |         trainable=False
30 |     )
31 | 
32 |     avg_var = tf.get_variable(
33 |         name=name + "_var",
34 |         shape=params_shape,
35 |         initializer=tf.constant_initializer(1.0),
36 |         collections=collections,
37 |         trainable=False
38 |     )
39 | 
40 |     gamma = tf.get_variable(
41 |         name=name + "_gamma",
42 |         shape=params_shape,
43 |         initializer=tf.constant_initializer(1.0),
44 |         collections=collections
45 |     )
46 | 
47 |     beta = tf.get_variable(
48 |         name=name + "_beta",
49 |         shape=params_shape,
50 |         initializer=tf.constant_initializer(0.0),
51 |         collections=collections,
52 |     )
53 | 
54 |     if is_training:
55 |         avg_mean_assign_op = tf.no_op()
56 |         avg_var_assign_op = tf.no_op()
57 |         if update_batch_stats:
58 |             avg_mean_assign_op = tf.assign(
59 |                 avg_mean,
60 |                 FLAGS.bn_stats_decay_factor * avg_mean + (1 - FLAGS.bn_stats_decay_factor) * mean)
61 |             avg_var_assign_op = tf.assign(
62 |                 avg_var,
63 |                 FLAGS.bn_stats_decay_factor * avg_var + (n / (n - 1))
64 |                 * (1 - FLAGS.bn_stats_decay_factor) * var)
65 | 
66 |         with tf.control_dependencies([avg_mean_assign_op, avg_var_assign_op]):
67 |             z = (x - mean) / tf.sqrt(1e-6 + var)
68 |     else: 
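        # inference: normalize with the moving averages accumulated during training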
69 | z = (x - avg_mean) / tf.sqrt(1e-6 + avg_var) 70 | 71 | return gamma * z + beta 72 | 73 | 74 | def fc(x, dim_in, dim_out, seed=None, name='fc'): 75 | num_units_in = dim_in 76 | num_units_out = dim_out 77 | weights_initializer = tf.contrib.layers.variance_scaling_initializer(seed=seed) 78 | 79 | weights = tf.get_variable(name + '_W', 80 | shape=[num_units_in, num_units_out], 81 | initializer=weights_initializer) 82 | biases = tf.get_variable(name + '_b', 83 | shape=[num_units_out], 84 | initializer=tf.constant_initializer(0.0)) 85 | x = tf.nn.xw_plus_b(x, weights, biases) 86 | return x 87 | 88 | 89 | def conv(x, ksize, stride, f_in, f_out, padding='SAME', use_bias=False, seed=None, name='conv'): 90 | shape = [ksize, ksize, f_in, f_out] 91 | initializer = tf.contrib.layers.variance_scaling_initializer(seed=seed) 92 | weights = tf.get_variable(name + '_W', 93 | shape=shape, 94 | dtype='float', 95 | initializer=initializer) 96 | x = tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding=padding) 97 | 98 | if use_bias: 99 | bias = tf.get_variable(name + '_b', 100 | shape=[f_out], 101 | dtype='float', 102 | initializer=tf.zeros_initializer) 103 | return tf.nn.bias_add(x, bias) 104 | else: 105 | return x 106 | 107 | 108 | def avg_pool(x, ksize=2, stride=2): 109 | return tf.nn.avg_pool(x, 110 | ksize=[1, ksize, ksize, 1], 111 | strides=[1, stride, stride, 1], 112 | padding='SAME') 113 | 114 | 115 | def max_pool(x, ksize=2, stride=2): 116 | return tf.nn.max_pool(x, 117 | ksize=[1, ksize, ksize, 1], 118 | strides=[1, stride, stride, 1], 119 | padding='SAME') 120 | 121 | 122 | def ce_loss(logit, y): 123 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y)) 124 | 125 | 126 | def accuracy(logit, y): 127 | pred = tf.argmax(logit, 1) 128 | true = tf.argmax(y, 1) 129 | return tf.reduce_mean(tf.to_float(tf.equal(pred, true))) 130 | 131 | def pred_and_true(logit, y): 132 | pred = tf.argmax(logit, 1) 133 | true = tf.argmax(y, 1) 134 | return tf.to_float(tf.equal(pred, true)) 135 | 136 | def logsoftmax(x): 137 | xdev = x - tf.reduce_max(x, 1, keep_dims=True) 138 | lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keep_dims=True)) 139 | return lsm 140 | 141 | 142 | def kl_divergence_with_logit(q_logit, p_logit): 143 | q = tf.nn.softmax(q_logit) 144 | qlogq = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(q_logit), 1)) 145 | qlogp = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(p_logit), 1)) 146 | return qlogq - qlogp 147 | 148 | 149 | def entropy_y_x(logit): 150 | p = tf.nn.softmax(logit) 151 | return -tf.reduce_mean(tf.reduce_sum(p * logsoftmax(logit), 1)) 152 | -------------------------------------------------------------------------------- /lvat/svhn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | from scipy.io import loadmat 8 | 9 | import numpy as np 10 | from scipy import linalg 11 | import glob 12 | import pickle 13 | 14 | from six.moves import xrange # pylint: disable=redefined-builtin 15 | from six.moves import urllib 16 | 17 | import tensorflow as tf 18 | from dataset_utils import * 19 | 20 | DATA_URL_TRAIN = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat' 21 | DATA_URL_TEST = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat' 22 | 23 | NUM_EXAMPLES_TRAIN = 73257 24 | NUM_EXAMPLES_TEST = 26032 25 | 26 | 27 | def maybe_download_and_extract(): 28 | if not 
os.path.exists(FLAGS.data__dir): 29 | os.makedirs(FLAGS.data__dir) 30 | filepath_train_mat = os.path.join(FLAGS.data__dir, 'train_32x32.mat') 31 | filepath_test_mat = os.path.join(FLAGS.data__dir, 'test_32x32.mat') 32 | if not os.path.exists(filepath_train_mat) or not os.path.exists(filepath_test_mat): 33 | def _progress(count, block_size, total_size): 34 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 35 | sys.stdout.flush() 36 | 37 | urllib.request.urlretrieve(DATA_URL_TRAIN, filepath_train_mat, _progress) 38 | urllib.request.urlretrieve(DATA_URL_TEST, filepath_test_mat, _progress) 39 | 40 | # Training set 41 | print("Loading training data...") 42 | print("Preprocessing training data...") 43 | train_data = loadmat(FLAGS.data__dir + '/train_32x32.mat') 44 | train_x = (-127.5 + train_data['X']) / 255. 45 | train_x = train_x.transpose((3, 0, 1, 2)) 46 | train_x = train_x.reshape([train_x.shape[0], -1]) 47 | train_y = train_data['y'].flatten().astype(np.int32) 48 | train_y[train_y == 10] = 0 49 | 50 | # Test set 51 | print("Loading test data...") 52 | test_data = loadmat(FLAGS.data__dir + '/test_32x32.mat') 53 | test_x = (-127.5 + test_data['X']) / 255. 54 | test_x = test_x.transpose((3, 0, 1, 2)) 55 | test_x = test_x.reshape((test_x.shape[0], -1)) 56 | test_y = test_data['y'].flatten().astype(np.int32) 57 | test_y[test_y == 10] = 0 58 | 59 | np.save('{}/train_images'.format(FLAGS.data__dir), train_x) 60 | np.save('{}/train_labels'.format(FLAGS.data__dir), train_y) 61 | np.save('{}/test_images'.format(FLAGS.data__dir), test_x) 62 | np.save('{}/test_labels'.format(FLAGS.data__dir), test_y) 63 | 64 | 65 | def load_svhn(): 66 | maybe_download_and_extract() 67 | train_images = np.load('{}/train_images.npy'.format(FLAGS.data__dir)).astype(np.float32) 68 | train_labels = np.load('{}/train_labels.npy'.format(FLAGS.data__dir)).astype(np.float32) 69 | test_images = np.load('{}/test_images.npy'.format(FLAGS.data__dir)).astype(np.float32) 70 | test_labels = np.load('{}/test_labels.npy'.format(FLAGS.data__dir)).astype(np.float32) 71 | return (train_images, train_labels), (test_images, test_labels) 72 | 73 | 74 | def prepare_dataset(): 75 | (train_images, train_labels), (test_images, test_labels) = load_svhn() 76 | dirpath = os.path.join(FLAGS.data__dir, 'seed' + str(FLAGS.dataset_seed)) 77 | if not os.path.exists(dirpath): 78 | os.makedirs(dirpath) 79 | 80 | rng = np.random.RandomState(FLAGS.dataset_seed) 81 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 82 | print(rand_ix) 83 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 84 | 85 | labeled_ind = np.arange(FLAGS.num_labeled_examples) 86 | labeled_train_images, labeled_train_labels = _train_images[labeled_ind], _train_labels[labeled_ind] 87 | _train_images = np.delete(_train_images, labeled_ind, 0) 88 | _train_labels = np.delete(_train_labels, labeled_ind, 0) 89 | convert_images_and_labels(labeled_train_images, 90 | labeled_train_labels, 91 | os.path.join(dirpath, 'labeled_train.tfrecords')) 92 | convert_images_and_labels(train_images, train_labels, 93 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 94 | convert_images_and_labels(test_images, 95 | test_labels, 96 | os.path.join(dirpath, 'test.tfrecords')) 97 | 98 | # Construct dataset for validation 99 | train_images_valid, train_labels_valid = labeled_train_images, labeled_train_labels 100 | test_images_valid, test_labels_valid = \ 101 | _train_images[:FLAGS.num_valid_examples], 
_train_labels[:FLAGS.num_valid_examples] 102 | unlabeled_train_images_valid = np.concatenate( 103 | (train_images_valid, _train_images[FLAGS.num_valid_examples:]), axis=0) 104 | unlabeled_train_labels_valid = np.concatenate( 105 | (train_labels_valid, _train_labels[FLAGS.num_valid_examples:]), axis=0) 106 | convert_images_and_labels(train_images_valid, 107 | train_labels_valid, 108 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 109 | convert_images_and_labels(unlabeled_train_images_valid, 110 | unlabeled_train_labels_valid, 111 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 112 | convert_images_and_labels(test_images_valid, 113 | test_labels_valid, 114 | os.path.join(dirpath, 'test_val.tfrecords')) 115 | 116 | 117 | def inputs(batch_size=100, 118 | train=True, validation=False, 119 | shuffle=True, num_epochs=None): 120 | if validation: 121 | if train: 122 | filenames = ['labeled_train_val.tfrecords'] 123 | num_examples = FLAGS.num_labeled_examples 124 | else: 125 | filenames = ['test_val.tfrecords'] 126 | num_examples = FLAGS.num_valid_examples 127 | else: 128 | if train: 129 | filenames = ['labeled_train.tfrecords'] 130 | num_examples = FLAGS.num_labeled_examples 131 | else: 132 | filenames = ['test.tfrecords'] 133 | num_examples = NUM_EXAMPLES_TEST 134 | 135 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 136 | filename_queue = generate_filename_queue(filenames, FLAGS.data__dir, num_epochs) 137 | image, label = read(filename_queue) 138 | image = transform(tf.cast(image, tf.float32)) if train else image 139 | return generate_batch([image, label], num_examples, batch_size, shuffle) 140 | 141 | 142 | def unlabeled_inputs(batch_size=100, 143 | validation=False, 144 | shuffle=True): 145 | if validation: 146 | filenames = ['unlabeled_train_val.tfrecords'] 147 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 148 | else: 149 | filenames = ['unlabeled_train.tfrecords'] 150 | num_examples = NUM_EXAMPLES_TRAIN 151 | 152 | filenames = [os.path.join('seed' + str(FLAGS.dataset_seed), filename) for filename in filenames] 153 | filename_queue = generate_filename_queue(filenames, data_dir=FLAGS.data__dir) 154 | image, label = read(filename_queue) 155 | image = transform(tf.cast(image, tf.float32)) 156 | return generate_batch([image], num_examples, batch_size, shuffle) 157 | 158 | 159 | def main(argv): 160 | prepare_dataset() 161 | 162 | 163 | if __name__ == "__main__": 164 | FLAGS = tf.app.flags.FLAGS 165 | tf.app.flags.DEFINE_string('data__dir', './data/svhn', "") 166 | tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 167 | tf.app.flags.DEFINE_integer('num_labeled_examples', 1000, "The number of labeled examples") 168 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 169 | 170 | tf.app.run() 171 | -------------------------------------------------------------------------------- /lvat/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | def restore(sess, scope, ckpt): 5 | vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) 6 | saver = tf.train.Saver(vars_to_restore) 7 | saver.restore(sess, os.path.join(ckpt,"model.ckpt")) 8 | return 9 | 10 | def is_inited_or_not(sess): 11 | print('is_inited_or_not() was called') 12 | for var in tf.global_variables(): 13 | try: 14 | sess.run(var) 15 | print('inited:', var.name) 16 | except 
tf.errors.FailedPreconditionError:
17 |             print('uninited:', var.name)
18 |     return
19 | 
20 | def init_uninitialized_vars(sess):
21 | 
22 |     uninitialized_vars = []
23 |     for var in tf.global_variables():
24 |         try:
25 |             sess.run(var)
26 |         except tf.errors.FailedPreconditionError:
27 |             uninitialized_vars.append(var)
28 | 
29 |     print('uninitialized_vars to be initialized right now >>>>>>>>>>>>>>>>>>>>>>')
30 |     for var in uninitialized_vars:
31 |         print(var)
32 |     op_init = tf.variables_initializer(uninitialized_vars)
33 |     return op_init
34 | 
--------------------------------------------------------------------------------
/lvat/vat.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy
3 | import sys, os
4 | 
5 | import layers_vat as L
6 | import cnn
7 | 
8 | FLAGS = tf.app.flags.FLAGS
9 | 
10 | tf.app.flags.DEFINE_float('epsilon', 1.0, "norm length for (virtual) adversarial training ")
11 | tf.app.flags.DEFINE_integer('num_power_iterations', 1, "the number of power iterations")
12 | tf.app.flags.DEFINE_float('xi', 1e-6, "small constant for finite difference")
13 | 
14 | SCOPE_CLASSIFIER = 'scope_classifier'
15 | 
16 | 
17 | def logit(x, is_training=True, update_batch_stats=True, stochastic=True, seed=1234):
18 |     return cnn.logit(x, is_training=is_training,
19 |                      update_batch_stats=update_batch_stats,
20 |                      stochastic=stochastic,
21 |                      seed=seed)
22 | 
23 | 
56 | def forward(x, decoder=None, is_training=True, update_batch_stats=True, seed=1234):
57 | 
58 |     if decoder is not None:
59 |         # when decoder is given, input x is actually z. 
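        # LVAT: the perturbation target lives in the transformer's latent
        # space, so it is first decoded back to image space and the
        # reconstruction is fed to the classifier below.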
60 | 61 | if FLAGS.ae_type == 'Glow': 62 | # x must be (y,logdet,z) 63 | 64 | SCOPE_DECODER = "scope_glow" 65 | 66 | with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE): 67 | x = decoder(x) 68 | 69 | else: 70 | if FLAGS.ae_type == 'VAE': 71 | SCOPE_DECODER = "scope_vae" 72 | elif FLAGS.ae_type == 'AE': 73 | SCOPE_DECODER = "scope_ae" 74 | elif FLAGS.ae_type == 'DAE': 75 | SCOPE_DECODER = "scope_dae" 76 | 77 | with tf.variable_scope(SCOPE_DECODER, reuse=tf.AUTO_REUSE): 78 | x = decoder(x, is_train=False) 79 | 80 | with tf.variable_scope(SCOPE_CLASSIFIER, reuse=tf.AUTO_REUSE): 81 | if is_training: 82 | return logit(x, is_training=True, 83 | update_batch_stats=update_batch_stats, 84 | stochastic=True, seed=seed) 85 | else: 86 | return logit(x, is_training=False, 87 | update_batch_stats=update_batch_stats, 88 | stochastic=False, seed=seed) 89 | 90 | 91 | def get_normalized_vector(d): 92 | d /= (1e-12 + tf.reduce_max(tf.abs(d), list(range(1, len(d.get_shape()))), keep_dims=True)) 93 | d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), list(range(1, len(d.get_shape()))), keep_dims=True)) 94 | return d 95 | 96 | 97 | def generate_virtual_adversarial_perturbation_glow(latent, logit, decoder, is_training=True): 98 | 99 | y, logdet, z = latent 100 | d_y = tf.random_normal(shape=tf.shape(y)) 101 | d_z = tf.random_normal(shape=tf.shape(z)) 102 | 103 | for _ in range(FLAGS.num_power_iterations): 104 | d_y = FLAGS.xi * get_normalized_vector(d_y) 105 | d_z = FLAGS.xi * get_normalized_vector(d_z) 106 | logit_p = logit 107 | logit_m = forward((y+d_y, logdet, z+d_z), decoder, update_batch_stats=False, is_training=is_training) 108 | dist = L.kl_divergence_with_logit(logit_p, logit_m) 109 | grad_y = tf.gradients(dist, [d_y], aggregation_method=2)[0] 110 | grad_z = tf.gradients(dist, [d_z], aggregation_method=2)[0] 111 | d_y = tf.stop_gradient(grad_y) 112 | d_z = tf.stop_gradient(grad_z) 113 | 114 | return (FLAGS.epsilon * get_normalized_vector(d_y), FLAGS.epsilon * get_normalized_vector(d_z)) 115 | 116 | def generate_virtual_adversarial_perturbation(x, logit, decoder=None, is_training=True): 117 | 118 | # when decoder is given, input x is actually z. 119 | 120 | d = tf.random_normal(shape=tf.shape(x)) 121 | 122 | for _ in range(FLAGS.num_power_iterations): 123 | d = FLAGS.xi * get_normalized_vector(d) 124 | logit_p = logit 125 | logit_m = forward(x + d, decoder, update_batch_stats=False, is_training=is_training) 126 | dist = L.kl_divergence_with_logit(logit_p, logit_m) 127 | grad = tf.gradients(dist, [d], aggregation_method=2)[0] 128 | d = tf.stop_gradient(grad) 129 | 130 | return FLAGS.epsilon * get_normalized_vector(d) 131 | 132 | 133 | def virtual_adversarial_loss_glow(latent, logit, decoder, is_training=True, name="vat_loss"): 134 | 135 | y, logdet, z = latent 136 | 137 | r_vadv_y, r_vadv_z = generate_virtual_adversarial_perturbation_glow(latent, logit, decoder, is_training=is_training) 138 | logit = tf.stop_gradient(logit) 139 | logit_p = logit 140 | logit_m = forward((y+r_vadv_y, logdet, z+r_vadv_z), decoder, update_batch_stats=False, is_training=is_training) 141 | loss = L.kl_divergence_with_logit(logit_p, logit_m) 142 | return tf.identity(loss, name=name), r_vadv_y, r_vadv_z 143 | 144 | 145 | def virtual_adversarial_loss(x, logit, decoder=None, is_training=True, name="vat_loss"): 146 | # when decoder is given, input x is actually z. 
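    # VAT objective, as implemented above: find the adversarial direction
    # r_vadv by power iteration on KL(p(y|x) || p(y|x+d)), stop gradients
    # through the clean logit, and penalize KL(p(y|x) || p(y|x + r_vadv)).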
147 | r_vadv = generate_virtual_adversarial_perturbation(x, logit, decoder, is_training=is_training) 148 | logit = tf.stop_gradient(logit) 149 | logit_p = logit 150 | logit_m = forward(x + r_vadv, decoder, update_batch_stats=False, is_training=is_training) 151 | loss = L.kl_divergence_with_logit(logit_p, logit_m) 152 | return tf.identity(loss, name=name), r_vadv 153 | 154 | 155 | def generate_adversarial_perturbation(x, loss): 156 | grad = tf.gradients(loss, [x], aggregation_method=2)[0] 157 | grad = tf.stop_gradient(grad) 158 | return FLAGS.epsilon * get_normalized_vector(grad) 159 | 160 | 161 | def adversarial_loss(x, y, loss, is_training=True, name="at_loss"): 162 | r_adv = generate_adversarial_perturbation(x, loss) 163 | logit = forward(x + r_adv, is_training=is_training, update_batch_stats=False) 164 | loss = L.ce_loss(logit, y) 165 | return loss 166 | 167 | def pi_loss(logit_t, logit_s, name="pi_loss"): 168 | logit_t = tf.stop_gradient(logit_t) 169 | loss = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(logit_t, logit_s))) + FLAGS.epsilon) 170 | return tf.identity(loss, name=name) 171 | 172 | def sampler(mu, logsigma): 173 | sigma = tf.exp(logsigma) 174 | epsilon = FLAGS.epsilon * tf.truncated_normal(tf.shape(mu), mean=0, stddev=1.) 175 | return mu + sigma*epsilon 176 | 177 | def distort(x): 178 | def _distort(a_image): 179 | """ 180 | bounding_boxes: A Tensor of type float32. 181 | 3-D with shape [batch, N, 4] describing the N bounding boxes associated with the image. 182 | Bounding boxes are supplied and returned as [y_min, x_min, y_max, x_max] 183 | """ 184 | if FLAGS.is_aug_trans: 185 | a_image = tf.pad(a_image, [[2, 2], [2, 2], [0, 0]]) 186 | a_image = tf.random_crop(a_image, [32,32,3]) 187 | 188 | if FLAGS.is_aug_flip: 189 | a_image = tf.image.random_flip_left_right(a_image) 190 | 191 | if FLAGS.is_aug_rotate: 192 | from math import pi 193 | radian = tf.random_uniform(shape=(), minval=0, maxval=360) * pi / 180 194 | a_image = tf.contrib.image.rotate(a_image, radian, interpolation='BILINEAR') 195 | 196 | if FLAGS.is_aug_color: 197 | a_image = tf.image.random_brightness(a_image, max_delta=0.2) 198 | a_image = tf.image.random_contrast( a_image, lower=0.2, upper=1.8 ) 199 | a_image = tf.image.random_hue(a_image, max_delta=0.2) 200 | 201 | if FLAGS.is_aug_crop: 202 | # shape: [1, 1, 4] 203 | bounding_boxes = tf.constant([[[1/10, 1/10, 9/10, 9/10]]], dtype=tf.float32) 204 | 205 | begin, size, _ = tf.image.sample_distorted_bounding_box( 206 | (32,32,3), bounding_boxes, 207 | min_object_covered=(9.8/10.0), 208 | aspect_ratio_range=[9.5/10.0, 10.0/9.5]) 209 | 210 | a_image = tf.slice(a_image, begin, size) 211 | """ for the purpose of distorting not use tf.image.resize_image_with_crop_or_pad under """ 212 | a_image = tf.image.resize_images(a_image, [32,32]) 213 | """ due to the size of channel returned from tf.image.resize_images is not being given, 214 | specify it manually. 
""" 215 | a_image = tf.reshape(a_image, [32,32,3]) 216 | return a_image 217 | 218 | """ process batch times in parallel """ 219 | return tf.map_fn( _distort, x) 220 | 221 | 222 | -------------------------------------------------------------------------------- /vae/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /vae/VAE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/util') 5 | from layers import Layers 6 | from losses import LossFunctions 7 | import config as c 8 | 9 | #from tflearn.layers.normalization import batch_normalization 10 | 11 | class VAE(object): 12 | 13 | def __init__(self, resource): 14 | 15 | 16 | """ data and external toolkits """ 17 | self.d = resource.dh # dataset manager 18 | self.ls = Layers() 19 | self.lf = LossFunctions(self.ls, self.d, self.encoder) 20 | 21 | """ placeholders defined outside""" 22 | if c.DO_TRAIN: 23 | self.lr = resource.ph['lr'] 24 | 25 | 26 | def encoder(self, h, is_train, y=None): 27 | 28 | if is_train: 29 | _d = self.d 30 | #_ = tf.summary.image('image', tf.reshape(h, [-1, _d.h, _d.w, _d.c]), 10) 31 | 32 | scope = 'e_1' 33 | h = self.ls.conv2d(scope+'_1', h, 128, filter_size=(2,2), strides=(1,2,2,1), padding="VALID") 34 | h = tf.layers.batch_normalization(h, training=is_train, name=scope) 35 | h = tf.nn.relu(h) 36 | 37 | scope = 'e_2' 38 | h = self.ls.conv2d(scope+'_1', h, 256, filter_size=(2,2), strides=(1,2,2,1), padding="VALID") 39 | h = tf.layers.batch_normalization(h, training=is_train, name=scope) 40 | h = tf.nn.relu(h) 41 | 42 | scope = 'e_3' 43 | h = self.ls.conv2d(scope+'_1', h, 512, filter_size=(2,2), strides=(1,2,2,1), padding="VALID") 44 | h = tf.layers.batch_normalization(h, training=is_train, name=scope) 45 | #h = tf.nn.relu(h) 46 | h = tf.nn.tanh(h) 47 | 48 | # -> (b, 4, 4, 512) 49 | 50 | print('h:', h) 51 | #h = tf.reshape(h, (c.BATCH_SIZE, -1)) 52 | h = tf.reshape(h, (-1, 4*4*512)) 53 | print('h:', h) 54 | 55 | #sys.exit('aa') 56 | h = self.ls.denseV2('top_of_encoder', h, c.Z_SIZE*2, activation=None) 57 | print('h:', h) 58 | return self.ls.vae_sampler_w_feature_slice( h, c.Z_SIZE) 59 | 60 | def decoder(self, h, is_train): 61 | 62 | scope = 'top_of_decoder' 63 | #h = self.ls.denseV2(scope, h, 128, activation=self.ls.lrelu) 64 | h = self.ls.denseV2(scope, h, 512, activation=self.ls.lrelu) 65 | print('h:', scope, h) 66 | 67 | h = tf.reshape(h, (-1, 4,4,32)) 68 | print('h:', scope, h) 69 | 70 | scope = 'd_1' 71 | h = self.ls.deconv2d(scope+'_1', h, 512, filter_size=(2,2)) 72 | h = tf.layers.batch_normalization(h, training=is_train, name=scope) 73 | h = tf.nn.relu(h) 74 | print('h:', scope, h) 75 | 76 | scope = 'd_2' 77 | h = self.ls.deconv2d(scope+'_2', h, 256, filter_size=(2,2)) 78 | h = tf.layers.batch_normalization(h, training=is_train, name=scope) 79 | h = tf.nn.relu(h) 80 | print('h:', scope, h) 81 | 82 | scope = 'd_3' 83 | h = self.ls.deconv2d(scope+'_3', h, 128, filter_size=(2,2)) 84 | h = tf.layers.batch_normalization(h, training=is_train, name=scope) 85 | h = tf.nn.relu(h) 86 | print('h:', scope, h) 87 | 88 | scope = 'd_4' 89 | h = self.ls.conv2d(scope+'_4', h, 3, filter_size=(1,1), strides=(1,1,1,1), padding="VALID", activation=tf.nn.sigmoid) 90 | print('h:', scope, h) 91 | 92 | return h 93 | 94 | 95 | def 
build_graph_train(self, x_l, y_l): 96 | 97 | o = dict() # output 98 | loss = 0 99 | 100 | if c.IS_AUGMENTATION_ENABLED: 101 | x_l = distorted = self.distort(x_l) 102 | 103 | if c.IS_AUG_NOISE_TRUE: 104 | x_l = self.ls.get_corrupted(x_l, 0.15) 105 | 106 | z, mu, logsigma = self.encoder(x_l, is_train=True, y=y_l) 107 | 108 | x_reconst = self.decoder(z, is_train=True) 109 | 110 | """ p(x|z) Reconstruction Loss """ 111 | o['Lr'] = self.lf.get_loss_pxz(x_reconst, x_l, 'Bernoulli') 112 | o['x_reconst'] = x_reconst 113 | o['x'] = x_l 114 | loss += o['Lr'] 115 | 116 | 117 | """ VAE KL-Divergence Loss """ 118 | LAMBDA_VAE = 0.1 119 | o['mu'], o['logsigma'] = mu, logsigma 120 | # work around. [ToDo] make sure the root cause that makes kl loss inf 121 | #logsigma = tf.clip_by_norm( logsigma, 10) 122 | o['Lz'] = self.lf.get_loss_vae(c.Z_SIZE, mu,logsigma, _lambda=0.0) 123 | loss += LAMBDA_VAE * o['Lz'] 124 | 125 | 126 | """ set losses """ 127 | o['loss'] = loss 128 | self.o_train = o 129 | 130 | """ set optimizer """ 131 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) 132 | grads = optimizer.compute_gradients(loss) 133 | for i,(g,v) in enumerate(grads): 134 | if g is not None: 135 | #g = tf.Print(g, [g], "g %s = "%(v)) 136 | grads[i] = (tf.clip_by_norm(g,5),v) # clip gradients 137 | else: 138 | print('g is None:', v) 139 | v = tf.Print(v, [v], "v = ", summarize=10000) 140 | 141 | 142 | # update ema in batch_normalization 143 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 144 | self.op = optimizer.apply_gradients(grads) # return train_op 145 | 146 | 147 | def build_graph_test(self, x_l, y_l): 148 | 149 | o = dict() # output 150 | loss = 0 151 | 152 | 153 | z, mu, logsigma = self.encoder(x_l, is_train=False, y=y_l) 154 | 155 | x_reconst = self.decoder(mu, is_train=False) 156 | o['x_reconst'] = x_reconst 157 | o['x'] = x_l 158 | #o['Lr'] = self.lf.get_loss_pxz(x_reconst, x_l, 'LeastSquare') 159 | o['Lr'] = self.lf.get_loss_pxz(x_reconst, x_l, 'Bernoulli') 160 | #o['Lr'] = self.lf.get_loss_pxz(x_reconst, x_l, 'DiscretizedLogistic') 161 | #o['Lr'] = tf.reduce_mean(tf.keras.losses.binary_crossentropy(x_l, x_reconst)) 162 | loss += o['Lr'] 163 | 164 | 165 | """ set losses """ 166 | o['loss'] = loss 167 | self.o_test = o 168 | 169 | 170 | def distort(self, x): 171 | 172 | """ 173 | maybe helpful http://www.redhub.io/Tensorflow/tensorflow-models/src/master/inception/inception/image_processing.py 174 | """ 175 | _d = self.d 176 | 177 | def _distort(a_image): 178 | """ 179 | bounding_boxes: A Tensor of type float32. 180 | 3-D with shape [batch, N, 4] describing the N bounding boxes associated with the image. 
181 | Bounding boxes are supplied and returned as [y_min, x_min, y_max, x_max] 182 | """ 183 | if c.IS_AUG_TRANS_TRUE: 184 | a_image = tf.pad(a_image, [[2, 2], [2, 2], [0, 0]]) 185 | a_image = tf.random_crop(a_image, [_d.h, _d.w, _d.c]) 186 | 187 | if c.IS_AUG_FLIP_TRUE: 188 | a_image = tf.image.random_flip_left_right(a_image) 189 | 190 | if c.IS_AUG_ROTATE_TRUE: 191 | from math import pi 192 | radian = tf.random_uniform(shape=(), minval=0, maxval=360) * pi / 180 193 | a_image = tf.contrib.image.rotate(a_image, radian, interpolation='BILINEAR') 194 | 195 | if c.IS_AUG_COLOR_TRUE: 196 | a_image = tf.image.random_brightness(a_image, max_delta=0.2) 197 | a_image = tf.image.random_contrast( a_image, lower=0.2, upper=1.8 ) 198 | a_image = tf.image.random_hue(a_image, max_delta=0.2) 199 | 200 | if c.IS_AUG_CROP_TRUE: 201 | # shape: [1, 1, 4] 202 | bounding_boxes = tf.constant([[[1/10, 1/10, 9/10, 9/10]]], dtype=tf.float32) 203 | 204 | begin, size, _ = tf.image.sample_distorted_bounding_box( 205 | (_d.h, _d.w, _d.c), bounding_boxes, 206 | min_object_covered=(9.8/10.0), 207 | aspect_ratio_range=[9.5/10.0, 10.0/9.5]) 208 | 209 | a_image = tf.slice(a_image, begin, size) 210 | """ for the purpose of distorting not use tf.image.resize_image_with_crop_or_pad under """ 211 | a_image = tf.image.resize_images(a_image, [_d.h, _d.w]) 212 | """ due to the size of channel returned from tf.image.resize_images is not being given, 213 | specify it manually. """ 214 | a_image = tf.reshape(a_image, [_d.h, _d.w, _d.c]) 215 | return a_image 216 | 217 | """ process batch times in parallel """ 218 | return tf.map_fn( _distort, x) 219 | -------------------------------------------------------------------------------- /vae/build_AE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys, os, time 4 | from collections import namedtuple 5 | from tqdm import tqdm 6 | 7 | import config as c 8 | 9 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/util/') 10 | 11 | 12 | np.set_printoptions(threshold=np.inf) 13 | 14 | 15 | def test(epoch, m, sess, resource): 16 | 17 | if c.IS_TF_QUEUE_AVAILABLE: xy = None 18 | 19 | time_start = time.time() 20 | 21 | Lr = [] 22 | 23 | for i in tqdm(range(m.d.n_batches_test), leave=False): 24 | 25 | if not c.IS_TF_QUEUE_AVAILABLE: xy = m.d.get_next_batch(i, False) 26 | r = fetch_from_test_graph(m, sess, resource, xy) 27 | 28 | Lr.append( r['Lr']) 29 | 30 | Lr = np.mean(Lr, axis=0) 31 | 32 | elapsed_time = time.time() - time_start 33 | o = ("validation: epoch:%d, loss: %5f, time:%.3f" % (epoch, Lr, elapsed_time )) 34 | print(o) 35 | 36 | IS_CHECK_RECONST_IMG = True 37 | if IS_CHECK_RECONST_IMG: 38 | from eval import draw_x 39 | x, x_reconst = r['x'], r['x_reconst'] 40 | if c.BATCH_SIZE > 50: x, x_reconst = x[:50], x_reconst[:50] 41 | 42 | draw_x(x, x_reconst, y=None, filename='debug_in_training_epoch_%d'%(epoch)) 43 | 44 | return 45 | 46 | def eval(m, sess, resource, xy): 47 | 48 | x,y = xy 49 | b = c.BATCH_SIZE 50 | n = len(xy[1]) 51 | n_batches = n // b 52 | Lr, x_reconst = [],[] 53 | 54 | assert n_batches>0, 'something wrong, check batch_size and n_samples' 55 | for i in tqdm(range(n_batches), leave=False): 56 | r = fetch_from_test_graph(m, sess, resource, (x[i*b:(i+1)*b], y[i*b:(i+1)*b])) 57 | #print(r) 58 | Lr.append( r['Lr']) 59 | x_reconst.extend( r['x_reconst']) 60 | 61 | Lr = np.mean(Lr, axis=0) 62 | o = ("loss: %5f " % (Lr)) 63 | print(o) 64 | x_reconst = np.array(x_reconst) 65 | 
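    # x_reconst now holds one reconstruction per evaluated sample, in input
    # order; samples beyond the last full batch of size c.BATCH_SIZE were
    # dropped by the n // b batching above (hence the n_batches > 0 assert).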
66 | return x_reconst, Lr 67 | 68 | def fetch_from_test_graph(m, sess, resource, xy=None): 69 | 70 | if xy is None: feed_dict = None 71 | else: 72 | (_x, _y) = xy 73 | feed_dict = {resource.ph['x_test']:_x , resource.ph['y_test']:_y} 74 | return sess.run(m.o_test, feed_dict) 75 | 76 | 77 | 78 | def build(ckpt=None, graph=None): 79 | #with tf.Graph().as_default() as graph_ae: 80 | 81 | ########################################### 82 | """ Load Data """ 83 | ########################################### 84 | 85 | ph = {} 86 | if c.IS_TF_QUEUE_AVAILABLE: 87 | from HandleIIDDataTFRecord import HandleIIDDataTFRecord 88 | d = HandleIIDDataTFRecord() 89 | (x_train, y_train), x, (x_test, y_test) = d.get_tfrecords(c.TEST_IDXES) 90 | 91 | else: 92 | from HandleImageDataNumpy import HandleImageDataNumpy 93 | d = HandleImageDataNumpy(c.FLAGS.dataset, c.BATCH_SIZE) 94 | 95 | ph['x_train'] = x_train = tf.placeholder(tf.float32, shape=[None, d.h, d.w, d.c], name="ph_x_train") 96 | #ph['x_train'] = x_train = tf.placeholder(tf.float32, shape=[c.BATCH_SIZE, d.h, d.w, d.c], name="ph_x_train") 97 | ph['y_train'] = y_train = tf.placeholder(tf.float32, shape=[None, d.l], name="ph_y_train") 98 | ph['x'] = x = tf.placeholder(tf.float32, shape=[None, d.h, d.w, d.c], name="ph_x") 99 | ph['x_test'] = x_test = tf.placeholder(tf.float32, shape=[None, d.h, d.w, d.c], name="ph_x_test") 100 | #ph['x_test'] = x_test = tf.placeholder(tf.float32, shape=[c.BATCH_SIZE, d.h, d.w, d.c], name="ph_x_test") 101 | ph['y_test'] = y_test = tf.placeholder(tf.float32, shape=[None, d.l], name="ph_y_test") 102 | 103 | 104 | ########################################### 105 | """ Build Model Graphs """ 106 | ########################################### 107 | ph['lr'] = tf.placeholder(tf.float32, shape=(), name="ph_learning_rate") 108 | 109 | Resource = namedtuple('Resource', ('dh', 'merged', 'saver', 'ph')) 110 | resource = Resource(dh=d, merged=None, saver=None, ph=ph) 111 | 112 | if c.GENERATOR_IS == 'VAE': 113 | from VAE import VAE 114 | m = VAE( resource ) 115 | scope_name = "scope_vae" 116 | 117 | elif c.GENERATOR_IS == 'AE': 118 | from AE import AE 119 | m = AE( resource ) 120 | scope_name = "scope_ae" 121 | 122 | elif c.GENERATOR_IS == 'DAE': 123 | from DAE import DAE 124 | m = DAE( resource ) 125 | scope_name = "scope_dae" 126 | 127 | else: 128 | raise ValueError('invalid arg: c.GENERATOR_IS is %s '%(c.GENERATOR_IS )) 129 | 130 | with tf.variable_scope(scope_name) as scope: 131 | 132 | if c.DO_TRAIN : 133 | print('... now building the graph for training.') 134 | m.build_graph_train(x_train,y_train) # the third one is a dummy for future 135 | scope.reuse_variables() 136 | 137 | if c.DO_TEST : 138 | print('... now building the graph for test.') 139 | m.build_graph_test(x_test,y_test) 140 | 141 | 142 | ########################################### 143 | """ Init """ 144 | ########################################### 145 | init_op = tf.global_variables_initializer() 146 | #for v in tf.all_variables(): print("[DEBUG] %s : %s" % (v.name,v.get_shape())) 147 | 148 | saver = tf.train.Saver() 149 | config = tf.ConfigProto() 150 | config.gpu_options.allow_growth = True 151 | config.gpu_options.allocator_type = 'BFC' 152 | sess = tf.Session(config=config, graph=graph) 153 | 154 | 155 | if c.FLAGS.restore: 156 | if not ckpt: ckpt = c.FLAGS.file_ckpt 157 | print("... 
restore with:", ckpt) 158 | saver.restore(sess, ckpt) 159 | else: 160 | sess.run(init_op) 161 | 162 | resource = resource._replace(merged=tf.summary.merge_all(), saver=saver) 163 | #tf.get_default_graph().finalize() 164 | 165 | return m, sess, resource 166 | 167 | def train(m, sess, resource): 168 | 169 | print('... start training') 170 | 171 | _lr, ratio = c.STARTER_LEARNING_RATE, 1.0 172 | _barrier_depth, barrier_growth = 0.,0. 173 | for epoch in range(1, c.N_EPOCHS+1): 174 | 175 | 176 | loss, Lr, Lz = [],[],[] 177 | for i in range(m.d.n_batches_train): 178 | 179 | if c.IS_TF_QUEUE_AVAILABLE: 180 | feed_dict = {resource.ph['lr']:_lr, 181 | } 182 | else: 183 | _x, _y = m.d.get_next_batch(i, True) 184 | feed_dict = {resource.ph['lr']:_lr, resource.ph['x_train']:_x , resource.ph['y_train']:_y, 185 | } 186 | 187 | """ do update """ 188 | time_start = time.time() 189 | #summary, r, op, current_lr = sess.run([resource.merged, m.o_train, m.op, m.lr], feed_dict=feed_dict) 190 | r, op, current_lr = sess.run([ m.o_train, m.op, m.lr], feed_dict=feed_dict) 191 | elapsed_time = time.time() - time_start 192 | 193 | loss.append(r['loss']) 194 | if c.GENERATOR_IS == 'VAE': 195 | Lr.append(r['Lr']) 196 | Lz.append(r['Lz']) 197 | 198 | #if i == 0: 199 | # print('debug:', r['logit'][-1]) 200 | 201 | #if i % 5 == 0 and i != 0: 202 | # break 203 | 204 | if ~np.isfinite(r['loss']).all(): 205 | print('mu:', r['mu']) 206 | print('logsigma', r['logsigma']) 207 | print(" iter:%2d, Lr: %.5f, Lz: %.5f, time:%.3f" % \ 208 | (i, np.mean(np.array(r['Lr'])), np.mean(np.array(r['Lz'])), elapsed_time)) 209 | print('mu:', np.mean(np.array(r['mu']))) 210 | print('logsigma:', np.mean(np.array(r['logsigma']))) 211 | sys.exit('nan was detected in loss') 212 | 213 | 214 | #if i % 100 == 0 and i != 0: 215 | if i % 500 == 0 and i != 0: 216 | #print('debug:', r['x']) 217 | #print('debug:', r['x_reconst']) 218 | 219 | # Debug 220 | """ 221 | import matplotlib.pyplot as plt 222 | plt.figure(figsize=(20,20)) 223 | for i in range(30): 224 | plt.subplot(5,6,i+1) 225 | plt.imshow(r['debug'][i]) 226 | plt.show() 227 | """ 228 | 229 | if c.GENERATOR_IS == 'VAE': 230 | print(" iter:%2d, loss: %.5f, Lr: %.5f, Lz: %.5f, time:%.3f" % \ 231 | (i, np.mean(np.array(loss)), np.mean(np.array(Lr)), np.mean(np.array(Lz)), elapsed_time)) 232 | else: 233 | print(" iter:%2d, loss: %.5f, time:%.3f" % \ 234 | (i, np.mean(np.array(loss)), elapsed_time)) 235 | 236 | print("training: epoch:%d, loss: %.5f" % \ 237 | ((epoch, np.mean(np.array(loss)), )), flush=True) 238 | 239 | """ test """ 240 | if c.DO_TEST and epoch % 1 == 0: 241 | test(epoch, m, sess, resource) 242 | 243 | """ save """ 244 | #if epoch % 5 == 0: 245 | if epoch % 1 == 0: 246 | print("Model saved in file: %s" % resource.saver.save(sess,c.FLAGS.file_ckpt)) 247 | 248 | """ learning rate decay""" 249 | if (epoch % c.DECAY_INTERVAL == 0) and (epoch > c.DECAY_AFTER): 250 | ratio *= c.DECAY_FACTOR 251 | _lr = c.STARTER_LEARNING_RATE * ratio 252 | #print('lr decaying is scheduled. epoch:%d, lr:%f <= %f' % ( epoch, _lr, current_lr)) 253 | 254 | if __name__ == "__main__": 255 | m, sess, resource = build() 256 | 257 | if c.IS_TF_QUEUE_AVAILABLE: tf.train.start_queue_runners(sess=sess) 258 | if c.DO_TRAIN: train(m, sess, resource) 259 | print('... now testing') 260 | test(0, m, sess, resource) 261 | print('... 
done.') 262 | sess.close() 263 | -------------------------------------------------------------------------------- /vae/config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os,sys 4 | 5 | IS_TF_QUEUE_AVAILABLE = True 6 | 7 | DO_TRAIN = True 8 | DO_TEST = True 9 | DO_VIEW = False 10 | DO_TEST_DEFENCE = False 11 | DO_REGION_CLASSIFY = False 12 | 13 | GENERATOR_IS = "AE" 14 | GENERATOR_IS = "VAEGAN" 15 | GENERATOR_IS = "GAN" 16 | GENERATOR_IS = "VAE" 17 | 18 | IS_WGAN = False 19 | IS_BN_ENABLED = True 20 | 21 | Z_SIZE = 200 22 | Z_SIZE = 128 23 | 24 | BATCH_SIZE = 2 25 | BATCH_SIZE = 9 26 | BATCH_SIZE = 256 # for VAE/GAN purposes 27 | 28 | 29 | if IS_TF_QUEUE_AVAILABLE: 30 | BATCH_SIZE = 32 # for VAT(SSL) 31 | BATCH_SIZE_UL = 128 # for VAT(SSL) 32 | BATCH_SIZE_TEST = 100 # for VAT(SSL) 33 | 34 | tf.flags.DEFINE_string("dataset", "SVHN", "CIFAR10 / SVHN ") 35 | tf.flags.DEFINE_boolean("restore", False, "restore from the last checkpoint") 36 | tf.flags.DEFINE_string("dir_logs", "./out/", "") 37 | tf.flags.DEFINE_string("file_ckpt", "", "") 38 | tf.flags.DEFINE_boolean("use_pi", True, "") 39 | tf.flags.DEFINE_boolean("use_vat", False, "") 40 | tf.flags.DEFINE_boolean("use_fgsm", True, "") 41 | tf.flags.DEFINE_boolean("use_virtual_cw", False, "") 42 | tf.flags.DEFINE_boolean("use_am_softmax", False, "") 43 | tf.flags.DEFINE_boolean("use_particle_barrier", False, "") 44 | 45 | FLAGS = tf.flags.FLAGS 46 | 47 | 48 | def set_condition(_use_pi, _use_vat, _use_fgsm, _use_virtual_cw, _use_am_softmax, _use_particle_barrier, dir_logs=FLAGS.dir_logs): 49 | if DO_TRAIN: 50 | pass 51 | else: 52 | #if dir_logs is None: 53 | # dir_logs = "out_%s_vat_%d__fgsm_%d__vcw_%d__ams_%d__pb_%d" % \ 54 | # (FLAGS.dataset, int(_use_pi), int(_use_vat), int(_use_fgsm), int(_use_virtual_cw), 55 | # int(_use_am_softmax), int(_use_particle_barrier)) 56 | FLAGS.dir_logs = dir_logs 57 | FLAGS.file_ckpt = os.path.join(FLAGS.dir_logs,"model.ckpt") 58 | FLAGS.use_pi = _use_pi 59 | FLAGS.use_vat = _use_vat 60 | FLAGS.use_fgsm = _use_fgsm 61 | FLAGS.use_virtual_cw = _use_virtual_cw 62 | FLAGS.use_am_softmax = _use_am_softmax 63 | FLAGS.use_particle_barrier = _use_particle_barrier 64 | print('[INFO] dir_logs was set as: ', FLAGS.dir_logs) 65 | return 66 | 67 | set_condition(FLAGS.use_pi, FLAGS.use_vat, FLAGS.use_fgsm, FLAGS.use_virtual_cw, FLAGS.use_am_softmax, FLAGS.use_particle_barrier) 68 | 69 | 70 | if not DO_TRAIN and not FLAGS.restore: 71 | print('[WARN] FLAGS.restore is forcibly set to True') 72 | FLAGS.restore = True 73 | 74 | K = 10 75 | TEST_IDXES = [9] 76 | PATHS = ( [''], None) 77 | 78 | N_EPOCHS = 500 79 | N_PLOTS = 2000 80 | 81 | IS_AUGMENTATION_ENABLED = True 82 | IS_AUG_TRANS_TRUE = True 83 | IS_AUG_FLIP_TRUE = False 84 | IS_AUG_ROTATE_TRUE = False 85 | IS_AUG_COLOR_TRUE = False 86 | IS_AUG_CROP_TRUE = False 87 | IS_AUG_NOISE_TRUE = False 88 | 89 | # learning rate decay 90 | STARTER_LEARNING_RATE = 1e-3 91 | STARTER_LEARNING_RATE = 2e-4 # DCGAN 92 | #STARTER_LEARNING_RATE = 0.001 # VAT 93 | DECAY_AFTER = 2 94 | #DECAY_AFTER = 80 # VAT 95 | DECAY_INTERVAL = 2 96 | DECAY_FACTOR = 0.97 97 | 98 | # Pi 99 | PI_COOL_DOWN_START = 100 100 | PI_COOL_DOWN_DURATION = 350 101 | LAMBDA_PI_MIN = 1 102 | assert PI_COOL_DOWN_DURATION > 1, 'duration must be longer than 1 epoch since LAMBDA_PI would stay at the starting point.'
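For reference, the learning-rate constants above are consumed by `train()` in build_AE.py as a stepwise exponential decay; a minimal standalone sketch of the resulting schedule (plain Python, constants copied from this config, illustration only):

```
STARTER_LEARNING_RATE = 2e-4            # the active assignment above
DECAY_AFTER, DECAY_INTERVAL, DECAY_FACTOR = 2, 2, 0.97
N_EPOCHS = 500

ratio, lrs = 1.0, []
for epoch in range(1, N_EPOCHS + 1):
    lrs.append(STARTER_LEARNING_RATE * ratio)   # lr in effect for this epoch
    # same guard as at the bottom of train() in build_AE.py
    if epoch % DECAY_INTERVAL == 0 and epoch > DECAY_AFTER:
        ratio *= DECAY_FACTOR

print(lrs[0], lrs[-1])   # 2e-4 at epoch 1, roughly 1e-7 by epoch 500
```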
103 | LAMBDA_PI = np.linspace(1, LAMBDA_PI_MIN, PI_COOL_DOWN_DURATION) 104 | 105 | 106 | IS_DO_ENABLE = False 107 | 108 | # Region Classifier 109 | N_PARTICLES_FOR_REGION_CLASSIFIER = 128 # number of neighbour points to be generated 110 | REGION_RADIUS = 0.3 111 | IS_GENERATE_PARTICLE_LOGIT_W_DROPOUT = True 112 | if IS_GENERATE_PARTICLE_LOGIT_W_DROPOUT: 113 | IS_DO_ENABLE = True 114 | REGION_RADIUS = 0.0 115 | 116 | if DO_REGION_CLASSIFY: 117 | if DO_TRAIN: 118 | pass 119 | sys.exit('Are you sure to start training w/ region classify ? If yes, comment out this line.') 120 | 121 | if BATCH_SIZE * N_PARTICLES_FOR_REGION_CLASSIFIER > 500: 122 | print('BATCH_SIZE / N_PARTICLES_FOR_REGION_CLASSIFIER =', BATCH_SIZE, N_PARTICLES_FOR_REGION_CLASSIFIER) 123 | sys.exit('Maybe too much.') 124 | 125 | 126 | # Particle Barrier 127 | IS_ORTHOGONALITY_ENABLE = True 128 | N_PARTICLES_FOR_BARRIER = 5 # number of neighbour points to be generated 129 | IS_SOFT_SHELL_ENABLE = False 130 | BARRIER_MODE = 'supervised' 131 | 132 | 133 | if FLAGS.dataset == "SVHN": 134 | STARTER_BARRIER_DEPTH_MAX = 0.01 135 | BARRIER_ACTIVATES_AFTER = 50 136 | BARRIER_ACTIVATES_AFTER = 0 137 | BARRIER_GROWTH_INTERVAL = 5 138 | BARRIER_GROWTH = 1 139 | 140 | elif FLAGS.dataset == "CIFAR10": 141 | STARTER_BARRIER_DEPTH_MAX = 0.01 142 | BARRIER_ACTIVATES_AFTER = 0 143 | BARRIER_GROWTH_INTERVAL = 5 144 | BARRIER_GROWTH = 1 145 | 146 | 147 | DIVERGENCE = 'least_square' 148 | #DIVERGENCE = 'kl_forward' 149 | #DIVERGENCE = 'kl_reverse' 150 | #DIVERGENCE = 'js' 151 | #DIVERGENCE = 'mmd' 152 | 153 | # SOFTMAX 154 | SOFTMAX_DEDUCTON = 0.35 # for MNIST 155 | #SOFTMAX_INVERSE_TEMPERATURE = 30 156 | SOFTMAX_INVERSE_TEMPERATURE = 1 157 | 158 | # VAT 159 | IS_RELAXED_KL_ENABLE = False 160 | 161 | # Regarding measuring the Distance to Decision Boundary 162 | DDB_N_DIRECTIONS = 10 163 | DDB_STEP = 0.01 164 | DDB_MAX = 0.5 165 | FILE_OF_DDB_DIRECTIONS = 'gxr3_directions_%d.npy'%(DDB_N_DIRECTIONS) 166 | 167 | # FGSM 168 | EPSILON_FGSM = 0.1 169 | 170 | # attack 171 | ADV_TARGET_CLASS = 0 172 | CW_CONFIDENCE = 20 173 | #CW_MAX_ITERATIONS = 1000 174 | CW_MAX_ITERATIONS = 10000 175 | N_BINARY_SEARCH = 10 176 | BOUND_BINARY_SEARCH = (10**-6, 1) # 1e-06 177 | #BOUND_BINARY_SEARCH = (0.00001, 1) # 1e-06 178 | 179 | DIR_DATA = './data/%s.confidence_%s/'%(FLAGS.dataset, CW_CONFIDENCE) 180 | X_CLASSIFIED_CORRECTLY = DIR_DATA + 'x_classified_correctly.npy' 181 | Y_CLASSIFIED_CORRECTLY = DIR_DATA + 'y_classified_correctly.npy' 182 | X_ORIGINAL = DIR_DATA + 'x_original.npy' 183 | Y_ORIGINAL = DIR_DATA + 'y_original.npy' 184 | X_ADVERSARIAL = DIR_DATA + 'x_adversarial.npy' 185 | Y_ADVERSARIAL_TARGET = DIR_DATA + 'y_adversarial_target.npy' 186 | X_TRAIN_CLASSIFIED_CORRECTLY = DIR_DATA + 'x_train_classified_correctly.npy' 187 | Y_TRAIN_CLASSIFIED_CORRECTLY = DIR_DATA + 'y_train_classified_correctly.npy' 188 | X_TRAIN_ORIGINAL = DIR_DATA + 'x_train_original.npy' 189 | Y_TRAIN_ORIGINAL = DIR_DATA + 'y_train_original.npy' 190 | X_TRAIN_ADVERSARIAL = DIR_DATA + 'x_train_adversarial.npy' 191 | Y_TRAIN_ADVERSARIAL_TARGET = DIR_DATA + 'y_train_adversarial_target.npy' 192 | FILE_MODEL = './nn_robust_attacks/models/%s'%(FLAGS.dataset) 193 | 194 | # sanitize 195 | SANITIZER = 'AE' 196 | SANITIZER = 'PGD_X' 197 | SANITIZER = 'PGD_Z' 198 | 199 | 200 | LAMBDA_PREDICTION_VARIANCE = 0.1 201 | LAMBDA_RECONSTRUCTION = 1 202 | IS_LOSS_PREDICTION_VARIANCE_ENABLED = False 203 | IS_LOSS_RECONSTRUCTION_ENABLED = True 204 | 205 | STARTER_LEARNING_RATE_ADV = 1e-3 206 | 
STARTER_LEARNING_RATE_ADV = 0.2 207 | STARTER_LEARNING_RATE_ADV_GAN = 0.2 208 | 209 | N_EPOCHS_ADV = 1000 210 | N_EPOCHS_ADV = 500 211 | N_EPOCHS_ADV = 300 212 | N_EPOCHS_ADV = 100 213 | DECAY_AFTER_ADV = 300 214 | DECAY_INTERVAL_ADV = 100 215 | DECAY_FACTOR_ADV = 0.97 216 | 217 | -------------------------------------------------------------------------------- /vae/util/HandleImageDataNumpy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import os, sys, time, math 6 | 7 | IS_NHWC_or_1D = 'NHWC' # 'NHWC'/'1D' 8 | 9 | class HandleImageDataNumpy(object): 10 | 11 | def __init__(self, dataset, batch_size): 12 | self.dataset = dataset 13 | 14 | # image width,height 15 | if self.dataset == "MNIST": 16 | from tensorflow.examples.tutorials import mnist 17 | _h, _w, _c = 28,28,1 18 | img_size = _h*_w*_c # the canvas size 19 | _l = 10 20 | elif self.dataset == "CIFAR10": 21 | _h, _w, _c = 32,32,3 22 | img_size = _h*_w*_c 23 | _l = 10 24 | elif self.dataset == "SVHN": 25 | #import loadSVHNKingma as svhn 26 | #PCA_DIM = 768 27 | PCA_DIM = -1 # no compressed raw data 28 | #img_size = PCA_DIM # PCA 29 | _h, _w, _c = 32,32,3 30 | img_size = _h*_w*_c 31 | _l = 10 32 | elif self.dataset == "KaggleBreastHistology": 33 | _h, _w, _c = 50,50,3 34 | img_size = _h*_w*_c 35 | _l = 2 36 | elif self.dataset == "BreaKHis": 37 | _h, _w, _c = 460,700,3 38 | img_size = _h*_w*_c 39 | _l = 2 40 | elif self.dataset == "Kyoto2006": 41 | from loadKyoto2006 import loadKyoto2006 42 | _h, _w, _c = None,None,None 43 | img_size = None # dummy 44 | _l = 2 45 | else: sys.exit("invalid dataset") 46 | 47 | self.h = _h 48 | self.w = _w 49 | self.c = _c 50 | self.l = _l 51 | self.img_size = img_size 52 | self.batch_size = batch_size 53 | 54 | 55 | if self.dataset == "MNIST": 56 | PATH_OF_MNIST = "D:/data/img/MNIST/" 57 | data_directory = PATH_OF_MNIST 58 | if not os.path.exists(data_directory): os.makedirs(data_directory) 59 | mnist_datasets = mnist.input_data.read_data_sets(data_directory, one_hot=True) 60 | dataset_train, dataset_test = mnist_datasets.train, mnist_datasets.test # binarized (0-1) mnist data 61 | 62 | n_examples_train = dataset_train.images.shape[0] 63 | n_examples_test = dataset_test.images.shape[0] 64 | 65 | elif self.dataset == "CIFAR10": 66 | #from cifar10 import load_cifar10 67 | from keras.datasets import cifar10 68 | (data_train, labels_train), (data_test, labels_test) = cifar10.load_data() # [0-255] integer 69 | data_train = data_train / 255. 70 | data_test = data_test / 255. 
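            # keras.datasets.cifar10 returns uint8 NHWC arrays; dividing by
            # 255 maps pixels into [0, 1], matching the sigmoid output of the
            # VAE decoder and the Bernoulli reconstruction loss.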
71 | 72 | if IS_NHWC_or_1D == '1D': 73 | data_train, data_test = data_train.reshape((-1, img_size)), data_test.reshape((-1, img_size)) # NHWC to 1d 74 | data_train, data_test = data_train.astype(np.float32), data_test.astype(np.float32) 75 | labels_train, labels_test = labels_train.reshape((-1, )), labels_test.reshape((-1, )) # flatten 76 | 77 | # if normalized or zca-ed one is preferable, 78 | #data_train, labels_train, data_test, labels_test = cifar10.loadCIFAR10( PATH_OF_CIFAR10, use_cache=True) 79 | labels_train = self._one_hot_encoded(labels_train, 10) 80 | labels_test = self._one_hot_encoded(labels_test, 10) 81 | n_examples_train = len(data_train) 82 | n_examples_test = len(data_test) 83 | 84 | elif self.dataset == "SVHN": 85 | from svhn import load_svhn, NUM_EXAMPLES_TRAIN, NUM_EXAMPLES_TEST 86 | # data_train.shape is (604388,3072) w/ extra and (73257,3072) w/o extra 87 | #data_train, labels_train, data_test, labels_test = svhn.loadSVHN(cutoffdim=PCA_DIM, use_cache=False, use_extra=False) 88 | (data_train, labels_train), (data_test, labels_test) = load_svhn() 89 | labels_train = self._one_hot_encoded(labels_train, 10) 90 | labels_test = self._one_hot_encoded(labels_test, 10) 91 | """ 92 | n_examples_train = NUM_EXAMPLES_TRAIN 93 | n_examples_test = NUM_EXAMPLES_TEST 94 | 95 | """ 96 | n_examples_train = (data_train.shape[0]//self.batch_size) * self.batch_size # discard residual 97 | data_train, labels_train = data_train[0:n_examples_train, :], labels_train[0:n_examples_train, :] 98 | 99 | n_examples_test = data_test.shape[0]//self.batch_size * self.batch_size 100 | data_test, labels_test = data_test[0:n_examples_test, :], labels_test[0:n_examples_test, :] 101 | 102 | elif self.dataset == "KaggleBreastHistology": 103 | from HandleIIDDataTFRecord import HandleIIDDataTFRecord 104 | K = 10 105 | TEST_IDXES = [9] 106 | PATHS = ( ['D:/data/img/KaggleBreastHistology'], None) 107 | d = HandleIIDDataTFRecord( self.dataset, self.batch_size, K, PATHS, is_debug=False) 108 | 109 | (data_train, labels_train), (data_test, labels_test) = d.get_ndarrays(TEST_IDXES) 110 | labels_train = self._one_hot_encoded(labels_train, self.l) 111 | labels_test = self._one_hot_encoded(labels_test, self.l) 112 | 113 | 114 | #print('x:', data_train[0]) 115 | #print('y:', labels_train[0]) 116 | #sys.exit('kokomade') 117 | n_examples_train = len(data_train) 118 | n_examples_test = len(data_test) 119 | 120 | n_examples_train = (data_train.shape[0]//self.batch_size) * self.batch_size # discard residual 121 | n_examples_test = data_test.shape[0]//self.batch_size * self.batch_size 122 | 123 | elif self.dataset == "BreaKHis": 124 | from HandleIIDDataTFRecord import HandleIIDDataTFRecord 125 | K = 10 126 | TEST_IDXES = [9] 127 | PATHS = ( ['D:/data/img/BreaKHis/BreaKHis_v1/histology_slides/breast'], None) 128 | d = HandleIIDDataTFRecord( self.dataset, self.batch_size, K, PATHS, is_debug=False) 129 | 130 | (data_train, labels_train), (data_test, labels_test) = d.get_ndarrays(TEST_IDXES) 131 | labels_train = self._one_hot_encoded(labels_train, self.l) 132 | labels_test = self._one_hot_encoded(labels_test, self.l) 133 | 134 | n_examples_train = len(data_train) 135 | n_examples_test = len(data_test) 136 | 137 | #n_examples_train = (data_train.shape[0]//self.batch_size) * self.batch_size # discard residual 138 | #data_train, labels_train = data_train[0:n_examples_train, :], labels_train[0:n_examples_train, :] 139 | 140 | #n_examples_test = data_test.shape[0]//self.batch_size * self.batch_size 141 | #data_test, labels_test = 
data_test[0:n_examples_test, :], labels_test[0:n_examples_test, :] 142 | 143 | elif self.dataset == "Kyoto2006": 144 | data_train, labels_train = loadKyoto2006('train', use_sval=False, use_cache=True, as_onehot=True) 145 | data_test, labels_test = loadKyoto2006( 'test', use_sval=False, use_cache=True, as_onehot=True) 146 | 147 | print(data_train.shape, labels_train.shape) 148 | n_examples_train = (data_train.shape[0]//self.batch_size) * self.batch_size # discard residual 149 | data_train, labels_train = data_train[0:n_examples_train, :], labels_train[0:n_examples_train, :] 150 | 151 | n_examples_test = data_test.shape[0]//self.batch_size * self.batch_size 152 | data_test, labels_test = data_test[0:n_examples_test, :], labels_test[0:n_examples_test, :] 153 | 154 | self.img_size = data_train.shape[1] 155 | # ugly workaround for ImageInterface 156 | self.h, self.w, self.c = 1,1,self.img_size 157 | 158 | if self.dataset == "SVHN": 159 | pass 160 | else: 161 | assert(n_examples_train%self.batch_size ==0) 162 | assert(n_examples_test%self.batch_size ==0) 163 | 164 | if self.dataset == "MNIST": 165 | # following two properties are for MNIST. 166 | self.dataset_train = dataset_train 167 | self.dataset_test = dataset_test 168 | 169 | # below is just a trial for crafting adv examples in eval.py 170 | self.data_train, self.labels_train = dataset_train.next_batch(55000) # x: (BATCH_SIZE x img_size) 171 | self.data_test, self.labels_test = dataset_test.next_batch(10000) # x: (BATCH_SIZE x img_size) 172 | 173 | else: 174 | self.data_train = data_train 175 | self.labels_train = labels_train 176 | self.data_test = data_test 177 | self.labels_test = labels_test 178 | 179 | #if IS_NHWC_or_1D == 'NHWC': 180 | # self.dataset_train = np.reshape( self.dataset_train, (self.batch_size, self.h, self.w, self.c)) 181 | # self.dataset_test = np.reshape( self.dataset_test, (self.batch_size, self.h, self.w, self.c)) 182 | 183 | self.n_examples_train = n_examples_train 184 | self.n_examples_test = n_examples_test 185 | 186 | self.n_batches_train = int( self.n_examples_train/self.batch_size ) 187 | self.n_batches_test = int( self.n_examples_test/self.batch_size ) 188 | print('n_examples_train:%d, n_batches_train:%d, n_batches_test:%d' % \ 189 | (self.n_examples_train, self.n_batches_train, self.n_batches_test)) 190 | 191 | # DataHandler 192 | def _get_a_batch(self, data, labels, i): 193 | # data and labels should be (BATCH_SIZE, ..., ...) 194 | _batch_size = self.batch_size 195 | batch_data = data[ i*_batch_size:(i+1)*_batch_size] 196 | batch_labels = labels[ i*_batch_size:(i+1)*_batch_size] 197 | return batch_data, batch_labels 198 | 199 | def _get_a_batch_old(self, data, labels, step): 200 | #print(data.shape) 201 | _batch_size = self.batch_size 202 | size = labels.shape[0] 203 | offset = (step * _batch_size) % (size - _batch_size) 204 | batch_data = data[offset:(offset + _batch_size), ...] 205 | batch_labels = labels[offset:(offset + _batch_size)] 206 | return batch_data, batch_labels 207 | 208 | def _one_hot_encoded(self, class_numbers, num_classes=None): 209 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/06_CIFAR-10.ipynb 210 | # Find the number of classes if None is provided.
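        # e.g. _one_hot_encoded(np.array([0, 2, 1]), num_classes=3)
        #      -> [[1., 0., 0.], [0., 0., 1.], [0., 1., 0.]]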
211 | if num_classes is None: num_classes = np.max(class_numbers) + 1 212 | return np.eye(num_classes, dtype=float)[class_numbers] 213 | 214 | def get_next_batch(self, i, is_train): 215 | _batch_size = self.batch_size 216 | if is_train: 217 | if self.dataset == "MNIST": 218 | x,y = self.dataset_train.next_batch(_batch_size) # x: (BATCH_SIZE x img_size) 219 | else: 220 | x,y = self._get_a_batch(self.data_train, self.labels_train, i ) 221 | else: 222 | if self.dataset == "MNIST": 223 | x,y = self.dataset_test.next_batch(_batch_size) # x: (BATCH_SIZE x img_size) 224 | else: 225 | x,y = self._get_a_batch(self.data_test, self.labels_test, i ) 226 | 227 | if IS_NHWC_or_1D == 'NHWC': 228 | x = np.reshape( x, (_batch_size, self.h, self.w, self.c)) 229 | 230 | return x,y 231 | 232 | 233 | ############################ 234 | """ MISC """ 235 | ############################ 236 | class utils(object): 237 | def list2str(l): return ", ".join (map(str,l)) 238 | def list2mu(l,i, stepback=1): return np.mean(np.array(l[i])/ np.array(l[i-stepback])) 239 | 240 | if __name__ == '__main__': 241 | BATCH_SIZE = 50 242 | DATASET = 'MNIST' 243 | DATASET = 'CIFAR10' 244 | DATASET = 'SVHN' 245 | DATASET = 'BreaKHis' 246 | d = HandleImageDataNumpy(DATASET, BATCH_SIZE) 247 | _x, _y = d.get_next_batch(3, True) 248 | print(_x, _y) 249 | 250 | -------------------------------------------------------------------------------- /vae/util/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | """Routine for decoding the CIFAR-10 binary file format.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | import numpy as np 25 | from scipy import linalg 26 | import glob 27 | import pickle 28 | 29 | from six.moves import xrange # pylint: disable=redefined-builtin 30 | from six.moves import urllib 31 | 32 | import tensorflow as tf 33 | 34 | from dataset_utils import * 35 | 36 | DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 37 | 38 | #tf.app.flags.DEFINE_string('data_dir', 'C:/Users/fx29351/Python/data/CIFAR10', 'where to store the dataset') 39 | #tf.app.flags.DEFINE_integer('num_labeled_examples', 4000, "The number of labeled examples") 40 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 41 | #tf.app.flags.DEFINE_integer('dataset_seed', 1, "dataset seed") 42 | FLAGS = tf.app.flags.FLAGS 43 | 44 | # Process images of this size. Here it matches the original CIFAR image 45 | # size of 32 x 32. If one alters this number, then the entire model 46 | # architecture will change and any model would need to be retrained.
47 | IMAGE_SIZE = 32 48 | 49 | N_LABELED = 4000 50 | DATASET_SEED = 1 51 | DATA_DIR = 'D:/data/img/CIFAR10' 52 | 53 | # Global constants describing the CIFAR-10 data set. 54 | NUM_CLASSES = 10 55 | NUM_EXAMPLES_TRAIN = 50000 56 | NUM_EXAMPLES_TEST = 10000 57 | 58 | def load_cifar10(): 59 | """Download and extract the tarball from Alex's website.""" 60 | dest_directory = DATA_DIR 61 | if not os.path.exists(dest_directory): 62 | os.makedirs(dest_directory) 63 | filename = DATA_URL.split('/')[-1] 64 | filepath = os.path.join(dest_directory, filename) 65 | if not os.path.exists(filepath): 66 | def _progress(count, block_size, total_size): 67 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 68 | (filename, float(count * block_size) / 69 | float(total_size) * 100.0)) 70 | sys.stdout.flush() 71 | 72 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 73 | print() 74 | statinfo = os.stat(filepath) 75 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 76 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 77 | 78 | # Training set 79 | print("Loading training data...") 80 | train_images = np.zeros((NUM_EXAMPLES_TRAIN, 3 * 32 * 32), dtype=np.float32) 81 | train_labels = [] 82 | for i, data_fn in enumerate( 83 | sorted(glob.glob(DATA_DIR + '/cifar-10-batches-py/data_batch*'))): 84 | batch = unpickle(data_fn) 85 | train_images[i * 10000:(i + 1) * 10000] = batch['data'] 86 | train_labels.extend(batch['labels']) 87 | 88 | # geosada 170713 for generative model 89 | #train_images = (train_images - 127.5) / 255. 90 | # -> [0,1] 91 | train_images = train_images / 255. 92 | train_labels = np.asarray(train_labels, dtype=np.int64) 93 | 94 | rand_ix = np.random.permutation(NUM_EXAMPLES_TRAIN) 95 | train_images = train_images[rand_ix] 96 | train_labels = train_labels[rand_ix] 97 | 98 | print("Loading test data...") 99 | test = unpickle(DATA_DIR + '/cifar-10-batches-py/test_batch') 100 | test_images = test['data'].astype(np.float32) 101 | # geosada 170713 102 | #test_images = (test_images - 127.5) / 255. 103 | # -> [0,1] 104 | test_images = test_images / 255. 
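    # As on the training side, pixels stay in [0, 1]; the ZCA whitening used
    # by the original VAT pipeline is left commented out below because
    # whitened images are no longer valid inputs for the generative model.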
105 | test_labels = np.asarray(test['labels'], dtype=np.int64) 106 | 107 | # geosada 170713 for generative model 108 | """ 109 | print("Apply ZCA whitening") 110 | components, mean, train_images = ZCA(train_images) 111 | np.save('{}/components'.format(DATA_DIR), components) 112 | np.save('{}/mean'.format(DATA_DIR), mean) 113 | test_images = np.dot(test_images - mean, components.T) 114 | """ 115 | 116 | train_images = train_images.reshape( 117 | (NUM_EXAMPLES_TRAIN, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TRAIN, -1)) 118 | test_images = test_images.reshape( 119 | (NUM_EXAMPLES_TEST, 3, 32, 32)).transpose((0, 2, 3, 1)).reshape((NUM_EXAMPLES_TEST, -1)) 120 | return (train_images, train_labels), (test_images, test_labels) 121 | 122 | 123 | def prepare_dataset(): 124 | (train_images, train_labels), (test_images, test_labels) = load_cifar10() 125 | 126 | dirpath = os.path.join(DATA_DIR, 'seed' + str(DATASET_SEED)) 127 | if not os.path.exists(dirpath): 128 | os.makedirs(dirpath) 129 | 130 | rng = np.random.RandomState(DATASET_SEED) 131 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 132 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 133 | 134 | examples_per_class = int(N_LABELED / 10) 135 | labeled_train_images = np.zeros((N_LABELED, 3072), dtype=np.float32) 136 | labeled_train_labels = np.zeros((N_LABELED), dtype=np.int64) 137 | for i in xrange(10): 138 | ind = np.where(_train_labels == i)[0] 139 | labeled_train_images[i * examples_per_class:(i + 1) * examples_per_class] \ 140 | = _train_images[ind[0:examples_per_class]] 141 | labeled_train_labels[i * examples_per_class:(i + 1) * examples_per_class] \ 142 | = _train_labels[ind[0:examples_per_class]] 143 | _train_images = np.delete(_train_images, 144 | ind[0:examples_per_class], 0) 145 | _train_labels = np.delete(_train_labels, 146 | ind[0:examples_per_class]) 147 | 148 | rand_ix_labeled = rng.permutation(N_LABELED) 149 | labeled_train_images, labeled_train_labels = \ 150 | labeled_train_images[rand_ix_labeled], labeled_train_labels[rand_ix_labeled] 151 | 152 | convert_images_and_labels(labeled_train_images, 153 | labeled_train_labels, 154 | os.path.join(dirpath, 'labeled_train.tfrecords')) 155 | convert_images_and_labels(train_images, train_labels, 156 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 157 | convert_images_and_labels(test_images, 158 | test_labels, 159 | os.path.join(dirpath, 'test.tfrecords')) 160 | 161 | # Construct dataset for validation 162 | train_images_valid, train_labels_valid = \ 163 | labeled_train_images[FLAGS.num_valid_examples:], labeled_train_labels[FLAGS.num_valid_examples:] 164 | test_images_valid, test_labels_valid = \ 165 | labeled_train_images[:FLAGS.num_valid_examples], labeled_train_labels[:FLAGS.num_valid_examples] 166 | unlabeled_train_images_valid = np.concatenate( 167 | (train_images_valid, _train_images), axis=0) 168 | unlabeled_train_labels_valid = np.concatenate( 169 | (train_labels_valid, _train_labels), axis=0) 170 | convert_images_and_labels(train_images_valid, 171 | train_labels_valid, 172 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 173 | convert_images_and_labels(unlabeled_train_images_valid, 174 | unlabeled_train_labels_valid, 175 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 176 | convert_images_and_labels(test_images_valid, 177 | test_labels_valid, 178 | os.path.join(dirpath, 'test_val.tfrecords')) 179 | 180 | 181 | def inputs(batch_size=100, 182 | train=True, validation=False, 183 | shuffle=True, num_epochs=None): 184 
| if validation: 185 | if train: 186 | filenames = ['labeled_train_val.tfrecords'] 187 | num_examples = N_LABELED - FLAGS.num_valid_examples 188 | else: 189 | filenames = ['test_val.tfrecords'] 190 | num_examples = FLAGS.num_valid_examples 191 | else: 192 | if train: 193 | filenames = ['labeled_train.tfrecords'] 194 | num_examples = N_LABELED 195 | else: 196 | filenames = ['test.tfrecords'] 197 | num_examples = NUM_EXAMPLES_TEST 198 | 199 | # geosada 170701 200 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 201 | #filenames = ['C:/Users/fx29351/Python/data/CIFAR10/seed' + str(DATASET_SEED) + '/' + filename for filename in filenames] 202 | 203 | filename_queue = generate_filename_queue(filenames, DATA_DIR, num_epochs) 204 | image, label = read(filename_queue) 205 | image = transform(tf.cast(image, tf.float32)) if train else image 206 | return generate_batch([image, label], num_examples, batch_size, shuffle) 207 | 208 | 209 | def unlabeled_inputs(batch_size=100, 210 | validation=False, 211 | shuffle=True): 212 | if validation: 213 | filenames = ['unlabeled_train_val.tfrecords'] 214 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 215 | else: 216 | filenames = ['unlabeled_train.tfrecords'] 217 | num_examples = NUM_EXAMPLES_TRAIN 218 | 219 | # geosada 170701 220 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 221 | #filenames = ['C:/Users/fx29351/Python/data/CIFAR10/seed' + str(DATASET_SEED) + '/' + filename for filename in filenames] 222 | filename_queue = generate_filename_queue(filenames, DATA_DIR) 223 | image, label = read(filename_queue) 224 | image = transform(tf.cast(image, tf.float32)) 225 | return generate_batch([image], num_examples, batch_size, shuffle) 226 | 227 | 228 | def main(argv): 229 | prepare_dataset() 230 | 231 | 232 | if __name__ == "__main__": 233 | tf.app.run() 234 | -------------------------------------------------------------------------------- /vae/util/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os, sys, pickle 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | #FLAGS = tf.app.flags.FLAGS 7 | #tf.app.flags.DEFINE_bool('aug_trans', False, "") 8 | #tf.app.flags.DEFINE_bool('aug_flip', False, "") 9 | 10 | AUG_TRANS = False 11 | AUG_FLIP = False 12 | 13 | def unpickle(file): 14 | fp = open(file, 'rb') 15 | if sys.version_info.major == 2: 16 | data = pickle.load(fp) 17 | elif sys.version_info.major == 3: 18 | data = pickle.load(fp, encoding='latin-1') 19 | fp.close() 20 | return data 21 | 22 | 23 | def ZCA(data, reg=1e-6): 24 | mean = np.mean(data, axis=0) 25 | mdata = data - mean 26 | sigma = np.dot(mdata.T, mdata) / mdata.shape[0] 27 | U, S, V = linalg.svd(sigma) 28 | components = np.dot(np.dot(U, np.diag(1 / np.sqrt(S) + reg)), U.T) 29 | whiten = np.dot(data - mean, components.T) 30 | return components, mean, whiten 31 | 32 | 33 | def _int64_feature(value): 34 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 35 | 36 | 37 | def _bytes_feature(value): 38 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 39 | 40 | 41 | def convert_images_and_labels(images, labels, filepath): 42 | 43 | print('[DEBUG] inputs shape:', images.shape, labels.shape) # (4000, 3072) (4000,) 44 | num_examples = labels.shape[0] 45 | if images.shape[0] != num_examples: 46 | raise ValueError("Images size %d does not match label size %d." 
% 47 | (images.shape[0], num_examples)) 48 | print('Writing', filepath) 49 | writer = tf.python_io.TFRecordWriter(filepath) 50 | for index in range(num_examples): 51 | image = images[index].tolist() 52 | image_feature = tf.train.Feature(float_list=tf.train.FloatList(value=image)) 53 | #print('[DEBUG] image_feature:', image_feature) # float_list { value: xxx},...} 54 | example = tf.train.Example(features=tf.train.Features(feature={ 55 | 'height': _int64_feature(32), 56 | 'width': _int64_feature(32), 57 | 'depth': _int64_feature(3), 58 | 'label': _int64_feature(int(labels[index])), 59 | 'image': image_feature})) 60 | writer.write(example.SerializeToString()) 61 | writer.close() 62 | 63 | 64 | def read(filename_queue): 65 | reader = tf.TFRecordReader() 66 | print('filename_queue',filename_queue) 67 | _, serialized_example = reader.read(filename_queue) 68 | features = tf.parse_single_example( 69 | serialized_example, 70 | # Defaults are not specified since both keys are required. 71 | features={ 72 | 'image': tf.FixedLenFeature([3072], tf.float32), 73 | 'label': tf.FixedLenFeature([], tf.int64), 74 | }) 75 | 76 | # Convert label from a scalar uint8 tensor to an int32 scalar. 77 | image = features['image'] 78 | image = tf.reshape(image, [32, 32, 3]) 79 | label = tf.one_hot(tf.cast(features['label'], tf.int32), 10) 80 | return image, label 81 | 82 | 83 | def generate_batch( 84 | example, 85 | min_queue_examples, 86 | batch_size, shuffle): 87 | """ 88 | Arg: 89 | list of tensors. 90 | """ 91 | num_preprocess_threads = 1 92 | 93 | if shuffle: 94 | ret = tf.train.shuffle_batch( 95 | example, 96 | batch_size=batch_size, 97 | num_threads=num_preprocess_threads, 98 | capacity=min_queue_examples + 5 * batch_size, 99 | min_after_dequeue=min_queue_examples) 100 | else: 101 | ret = tf.train.batch( 102 | example, 103 | batch_size=batch_size, 104 | num_threads=num_preprocess_threads, 105 | allow_smaller_final_batch=True, 106 | capacity=min_queue_examples + 5 * batch_size) 107 | 108 | return ret 109 | 110 | 111 | def transform(image): 112 | image = tf.reshape(image, [32, 32, 3]) 113 | if AUG_TRANS or AUG_FLIP: 114 | print("augmentation") 115 | if AUG_TRANS: 116 | image = tf.pad(image, [[2, 2], [2, 2], [0, 0]]) 117 | image = tf.random_crop(image, [32, 32, 3]) 118 | if AUG_FLIP: 119 | image = tf.image.random_flip_left_right(image) 120 | return image 121 | 122 | 123 | def generate_filename_queue(filenames, data_dir, num_epochs=None): 124 | print("filenames in queue:", filenames) 125 | for i in range(len(filenames)): 126 | filenames[i] = os.path.join(data_dir, filenames[i]) 127 | return tf.train.string_input_producer(filenames, num_epochs=num_epochs) 128 | 129 | 130 | -------------------------------------------------------------------------------- /vae/util/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/..') 5 | import config as c 6 | #np.set_printoptions(threshold=np.inf) 7 | 8 | 9 | """ VAT hyper params """ 10 | if True: 11 | XI = 10 # small constant for finite difference 12 | EP = 1.0 # norm length for (virtual) adversarial training 13 | else: 14 | # orginal values in https://github.com/takerum/vat_tf/blob/master/vat.py 15 | XI = 1e-6 # small constant for finite difference 16 | EP = 8.0 # norm length for (virtual) adversarial training 17 | 18 | N_POWER_ITER = 1 # the number of power iterations 19 | 20 | CONFIDENCE = 0.2 21 | 22 | 
eps = 1e-8 23 | 24 | class LossFunctions(object): 25 | 26 | def __init__(self, layers, dataset, encoder=None): 27 | 28 | self.ls = layers 29 | self.d = dataset 30 | self.encoder = encoder 31 | #self.reconst_pixel_log_stdv = tf.get_variable("reconst_pixel_log_stdv", initializer=tf.constant(0.0)) 32 | 33 | def get_loss_classification(self, logit, y, class_weights=None, gamma=0.0 ): 34 | 35 | loss = self._ce(logit, y) 36 | accur = self.get_accuracy(logit, y, gamma) 37 | return loss, accur 38 | 39 | def get_loss_regression(self, logit, y): 40 | logit = tf.reshape( logit, [-1]) 41 | y = tf.reshape( y, [-1]) 42 | loss = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(logit, y))) + eps) 43 | #loss = tf.reduce_mean(tf.square(tf.subtract(logit, y))) 44 | #loss = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(logit, y))) + eps) 45 | return loss 46 | 47 | def get_loss_pxz(self, x_reconst, x_original, pxz): 48 | if pxz == 'Bernoulli': 49 | #loss = tf.reduce_mean( tf.reduce_sum(self._binary_crossentropy(x_original, x_reconst),1)) # reconstruction term 50 | loss = tf.reduce_mean( self._binary_crossentropy(x_original, x_reconst)) # reconstruction term 51 | elif pxz == 'LeastSquare': 52 | x_reconst = tf.reshape( x_reconst, [-1]) 53 | x_original = tf.reshape( x_original, [-1]) 54 | #loss = tf.sqrt(tf.square(tf.reduce_mean(tf.subtract(x_original, x_reconst))) + eps) 55 | loss = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(x_original, x_reconst))) + eps) 56 | elif pxz == 'PixelSoftmax': 57 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=x_reconst, labels=tf.cast(x_original, dtype=tf.int32))) / (self.d.img_size * 256) 58 | elif pxz == 'DiscretizedLogistic': 59 | loss = -tf.reduce_mean( self._discretized_logistic(x_reconst, x_original)) 60 | else: 61 | sys.exit('invalid argument') 62 | return loss 63 | 64 | def _binary_crossentropy(self, t,o): 65 | t = tf.reshape( t, (-1, self.d.img_size)) 66 | o = tf.reshape( o, (-1, self.d.img_size)) 67 | return -tf.reduce_sum((t*tf.log(o+eps) + (1.0-t)*tf.log(1.0-o+eps)), axis=1) 68 | 69 | def _discretized_logistic(self, x_reconst, x_original, binsize=1/256.0): 70 | # https://github.com/openai/iaf/blob/master/tf_utils/ 71 | scale = tf.exp(self.reconst_pixel_log_stdv) 72 | x_original = (tf.floor(x_original / binsize) * binsize - x_reconst) / scale 73 | 74 | logp = tf.log(tf.sigmoid(x_original + binsize / scale) - tf.sigmoid(x_original) + eps) 75 | 76 | shape = x_reconst.get_shape().as_list() 77 | if len(shape) == 2: # 1d 78 | indices = (1) 79 | elif len(shape) == 4: # cnn as NHWC 80 | indices = (1,2,3) 81 | else: 82 | raise ValueError('shape of x is unexpected') 83 | 84 | return tf.reduce_sum(logp, indices) 85 | 86 | def get_logits_variance(self, z): 87 | 88 | """ z: logits (batch_size, n_mc_sampling, n_class)""" 89 | 90 | def tf_cov(x): 91 | x = tf.squeeze(x) # -> (_n, n_class) 92 | mean_x = tf.reduce_mean(x, axis=0, keepdims=True) 93 | mx = tf.matmul(tf.transpose(mean_x), mean_x) 94 | vx = tf.matmul(tf.transpose(x), x)/tf.cast(tf.shape(x)[0], tf.float32) 95 | cov_xx = vx - mx 96 | return cov_xx 97 | 98 | def mean_diag_cov(x): 99 | # after upgrade tf 100 | #cov = tf.covariance(x) # -> (_n, _n) 101 | cov = tf_cov(x) # -> (_n, _n) 102 | eigenvalues,_ = tf.linalg.eigh(cov) 103 | o = tf.reduce_mean( eigenvalues) 104 | return o 105 | 106 | mean_diag_covs_z = tf.map_fn(mean_diag_cov, z) 107 | return tf.reduce_mean(mean_diag_covs_z) 108 | 109 | def get_loss_vae(self, dim_z, mu,logsigma, _lambda=1.0 ): 110 | 111 | """ KL( z_L || N(0,I)) """ 112 | #mu, logsigma = 
mu_logsigma 113 | sigma = tf.exp(logsigma) 114 | sigma2 = tf.square(sigma) 115 | 116 | kl = 0.5*tf.reduce_sum( (tf.square(mu) + sigma2) - 2*logsigma, 1) - dim_z*.5 117 | return tf.reduce_mean( tf.maximum(_lambda, kl )) 118 | 119 | def get_loss_kl(self, m, _lambda=1.0 ): 120 | 121 | L = m.L 122 | Z_SIZES = m.Z_SIZES 123 | 124 | """ KL divergence KL( q(z_l) || p(z_0)) at each layer, where p(z_0) is set as N(0,I) """ 125 | Lzs1 = [0]*L 126 | 127 | """ KL( q(z_l) || p(z_l)) to monitor the activities of latent variable units at each layer 128 | as Fig.4 in http://papers.nips.cc/paper/6275-ladder-variational-autoencoders.pdf """ 129 | Lzs2 = [0]*L 130 | 131 | for l in range(L): 132 | d_mu, d_logsigma = m.d_mus[l], m.d_logsigmas[l] 133 | p_mu, p_logsigma = m.p_mus[l], m.p_logsigmas[l] 134 | 135 | d_sigma = tf.exp(d_logsigma) 136 | p_sigma = tf.exp(p_logsigma) 137 | d_sigma2, p_sigma2 = tf.square(d_sigma), tf.square(p_sigma) 138 | 139 | kl1 = 0.5*tf.reduce_sum( (tf.square(d_mu) + d_sigma2) - 2*d_logsigma, 1) - Z_SIZES[l]*.5 140 | kl2 = 0.5*tf.reduce_sum( (tf.square(d_mu - p_mu) + d_sigma2)/p_sigma2 - 2*tf.log((d_sigma/p_sigma) + eps), 1) - Z_SIZES[l]*.5 141 | 142 | Lzs1[l] = tf.reduce_mean( tf.maximum(_lambda, kl1 )) 143 | Lzs2[l] = tf.reduce_mean( kl2 ) 144 | 145 | """ use only the KL divergence at the top layer, KL( z_L || z_0), as the loss for optimization """ 146 | loss = Lzs1[-1] 147 | #loss += tf.add_n(Lzs2) 148 | return Lzs1, Lzs2, loss 149 | 150 | def get_loss_mmd(self, x, y): 151 | """ 152 | https://github.com/ShengjiaZhao/MMD-Variational-Autoencoder/blob/master/mmd_vae.ipynb 153 | """ 154 | 155 | def _kernel(x, y): 156 | x_size = tf.shape(x)[0] 157 | y_size = tf.shape(y)[0] 158 | dim = tf.shape(x)[1] 159 | tiled_x = tf.tile(tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])) 160 | tiled_y = tf.tile(tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])) 161 | return tf.exp(-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)) 162 | 163 | x_kernel = _kernel(x, x) 164 | y_kernel = _kernel(y, y) 165 | xy_kernel = _kernel(x, y) 166 | return tf.reduce_mean(x_kernel) + tf.reduce_mean(y_kernel) - 2 * tf.reduce_mean(xy_kernel) 167 | 168 | def get_loss_kl_logit(self, logit_real, logit_virtual, mode, training_target_is): 169 | if training_target_is == 'real': 170 | logit_virtual = tf.stop_gradient(logit_virtual) 171 | elif training_target_is == 'virtual': 172 | logit_real = tf.stop_gradient(logit_real) 173 | else: 174 | raise ValueError('unexpected string was set in arg training_target_is.') 175 | #loss = self._kl_divergence_with_logit(logit_real, logit_virtual) 176 | if mode == 'kl_forward': 177 | loss = self._kl_divergence_with_logit(logit_real, logit_virtual) 178 | elif mode == 'kl_reverse': 179 | loss = self._kl_divergence_with_logit(logit_virtual, logit_real) 180 | elif mode == 'js': 181 | loss = self._js_divergence_with_logit(logit_real, logit_virtual) 182 | else: 183 | raise ValueError('unexpected string was set in arg mode.') 184 | return tf.identity(loss, name="loss_kl_logit") 185 | 186 | def get_loss_pi(self, x, logit_real, is_train): 187 | logit_real = tf.stop_gradient(logit_real) 188 | logit_virtual = self.encoder(x, is_train=is_train) 189 | if c.DIVERGENCE == 'mmd': 190 | loss = self.get_loss_mmd(logit_virtual, logit_real) 191 | elif c.DIVERGENCE == 'js': 192 | loss = self._js_divergence_with_logit(logit_real, logit_virtual) 193 | elif c.DIVERGENCE == 'least_square': 194 | loss = 
    def get_loss_kl_logit(self, logit_real, logit_virtual, mode, training_target_is):
        if training_target_is == 'real':
            logit_virtual = tf.stop_gradient(logit_virtual)
        elif training_target_is == 'virtual':
            logit_real = tf.stop_gradient(logit_real)
        else:
            raise ValueError('unexpected string was set in arg training_target_is.')
        if mode == 'kl_forward':
            loss = self._kl_divergence_with_logit(logit_real, logit_virtual)
        elif mode == 'kl_reverse':
            loss = self._kl_divergence_with_logit(logit_virtual, logit_real)
        elif mode == 'js':
            loss = self._js_divergence_with_logit(logit_real, logit_virtual)
        else:
            raise ValueError('unexpected string was set in arg mode.')
        return tf.identity(loss, name="loss_kl_logit")

    def get_loss_pi(self, x, logit_real, is_train):
        """ consistency loss between two forward passes of the same input """
        logit_real = tf.stop_gradient(logit_real)
        logit_virtual = self.encoder(x, is_train=is_train)
        if c.DIVERGENCE == 'mmd':
            loss = self.get_loss_mmd(logit_virtual, logit_real)
        elif c.DIVERGENCE == 'js':
            loss = self._js_divergence_with_logit(logit_real, logit_virtual)
        elif c.DIVERGENCE == 'least_square':
            loss = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(logit_real, logit_virtual))) + eps)
        else:
            sys.exit('invalid args: %s' % (c.DIVERGENCE))
        return logit_real, logit_virtual, loss

    def get_loss_logit_diff_org(self, logit, y_true):

        # [ToDo] replacing y_true with y_pred would make this unsupervised

        real = tf.reduce_sum(y_true * logit, 1)
        other = tf.reduce_max((1 - y_true) * logit - (y_true * 10000), 1)

        IS_TARGETED_ATTACK = False
        if IS_TARGETED_ATTACK:
            # if targeted, optimize for making the other class most likely
            loss = tf.maximum(0.0, other - real + CONFIDENCE)
        else:
            # if untargeted, optimize for making this class least likely
            loss = tf.maximum(0.0, real - other + CONFIDENCE)
        return loss

    def get_loss_logit_diff(self, logit_real, logit_virtual, y_true):

        # [ToDo] replacing y_true with y_pred would make this unsupervised

        real1 = tf.reduce_mean(y_true * logit_real, 1)
        real2 = tf.reduce_mean(y_true * logit_virtual, 1)

        if c.IS_RELAXED_KL_ENABLE:
            loss = self._kl_divergence_with_logit(y_true * logit_real, y_true * logit_virtual)
        else:
            loss = tf.sqrt(tf.reduce_mean(tf.abs(y_true * (logit_virtual - logit_real))) + eps)
        return loss, real1, real2

    def get_loss_virtual_cw(self, x, logit_real, y_true, is_train):
        r_vadv = self._generate_virtual_cw_perturbation(x, logit_real, y_true, is_train)
        logit_real = tf.stop_gradient(logit_real)
        logit_virtual = self.encoder(x + r_vadv, is_train=is_train)
        loss, real1, real2 = self.get_loss_logit_diff(logit_real, logit_virtual, y_true)
        return tf.identity(loss, name="vcw_loss"), real1, real2

    def get_loss_vat(self, x, logit_real, is_train, y=None):
        r_vadv = self._generate_virtual_adversarial_perturbation(x, logit_real, is_train, y)
        logit_real = tf.stop_gradient(logit_real)
        logit_virtual = self.encoder(x + r_vadv, is_train=is_train, y=y)

        if c.DIVERGENCE == 'mmd':
            loss = self.get_loss_mmd(logit_virtual, logit_real)
        elif c.DIVERGENCE == 'kl_forward':
            loss = self._kl_divergence_with_logit(logit_virtual, logit_real)
        elif c.DIVERGENCE == 'kl_reverse':
            loss = self._kl_divergence_with_logit(logit_real, logit_virtual)
        else:
            sys.exit('invalid args: %s' % (c.DIVERGENCE))
        return tf.identity(loss, name="vat_loss"), logit_real, logit_virtual

    def get_loss_fgsm(self, x, y, loss, is_train, is_fgsm=True, name="at_loss"):
        r_adv = self._generate_adversarial_perturbation(x, loss, is_fgsm)
        logit = self.encoder(x + r_adv, is_train=is_train)
        loss = self._ce(logit, y)
        return loss
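    # The two _generate_virtual_*_perturbation helpers below approximate the
    # most sensitive direction of the classifier by power iteration, as in
    # Virtual Adversarial Training (Miyato et al.): start from random noise d,
    # then repeatedly (i) rescale d to a small radius XI, (ii) measure the
    # divergence between the predictions at x and x + d, and (iii) replace d by
    # the gradient of that divergence w.r.t. d. The final direction is scaled
    # to radius EP.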
    def _get_normalized_vector(self, d):

        # reduce over all non-batch axes
        shape = d.get_shape().as_list()
        if len(shape) == 2:    # flattened (N, D)
            indices = (1,)
        elif len(shape) == 3:  # time-major sequential data as (T, N, embedding dimension)
            indices = (2,)
        elif len(shape) == 4:  # cnn as NHWC
            indices = (1, 2, 3)
        else:
            raise ValueError('shape of d is unexpected: %s' % (shape))

        d /= (1e-12 + tf.reduce_max(tf.abs(d), indices, keepdims=True))
        d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), indices, keepdims=True))
        return d

    def _generate_virtual_cw_perturbation(self, x, logit_real, y_true, is_train):
        d = tf.random_normal(shape=tf.shape(x))

        for _ in range(N_POWER_ITER):
            d = XI * self._get_normalized_vector(d)
            logit_virtual = self.encoder(x + d, is_train=is_train)
            dist, _, _ = self.get_loss_logit_diff(logit_real, logit_virtual, y_true)
            grad = tf.gradients(dist, [d], aggregation_method=2)[0]
            d = tf.stop_gradient(grad)

        return EP * self._get_normalized_vector(d)

    def _generate_virtual_adversarial_perturbation(self, x, logit_real, is_train, y):
        d = tf.random_normal(shape=tf.shape(x))

        for _ in range(N_POWER_ITER):
            d = XI * self._get_normalized_vector(d)
            logit_virtual = self.encoder(x + d, is_train=is_train, y=y)
            dist = self._kl_divergence_with_logit(logit_real, logit_virtual)
            grad = tf.gradients(dist, [d], aggregation_method=2)[0]
            d = tf.stop_gradient(grad)

        return EP * self._get_normalized_vector(d)

    def _generate_adversarial_perturbation(self, x, loss, is_fgsm):
        grad = tf.gradients(loss, [x], aggregation_method=2)[0]
        grad = tf.stop_gradient(grad)
        norm = self._get_normalized_vector(grad)
        if is_fgsm:
            return c.EPSILON_FGSM * tf.sign(norm)
        else:
            return c.EPSILON_FGSM * norm

    # the helpers below follow https://github.com/takerum/vat_tf/blob/master/layers.py
    def _ce(self, logit, y):
        unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y)
        if not hasattr(self.d, "class_weights"):
            return tf.reduce_mean(unweighted_losses)
        else:
            """ https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy """
            weights = tf.reduce_sum(self.d.class_weights * y, axis=1)
            weighted_losses = unweighted_losses * weights
            return tf.reduce_mean(weighted_losses)

    def get_accuracy(self, logit, y, gamma=0.0):
        pred = tf.argmax(logit, 1)
        true = tf.argmax(y, 1)
        return tf.reduce_mean(tf.to_float(tf.equal(pred, true)))

    def get_accuracy_w_rejection(self, logit, y, gamma=0.0):

        """ gamma: confidence threshold """
        prob = tf.reduce_max(tf.nn.softmax(logit), 1)
        pred = tf.cast(tf.argmax(logit, 1), tf.int32)
        true = tf.cast(tf.argmax(y, 1), tf.int32)
        is_hit = tf.cast(tf.equal(pred, true), tf.bool)
        accr = tf.reduce_mean(tf.to_float(is_hit))

        """ accuracy with rejecting unconfident examples """

        """ [ToDo] replace with tf.bitwise.invert after upgrading to TF 1.4 """
        cond = tf.greater(prob, gamma)
        cond_inv = tf.less_equal(prob, gamma)
        idxes = tf.reshape(tf.where(cond), [-1])
        idxes_inv = tf.reshape(tf.where(cond_inv), [-1])
        n_examples = tf.size(pred)
        n_rejected = n_examples - tf.size(idxes)

        pred_confident = tf.gather(pred, idxes)
        true_confident = tf.gather(true, idxes)
        accr_limited_in_w_confidence = tf.reduce_mean(tf.to_float(tf.equal(pred_confident, true_confident)))
        accr_w_confidence = tf.reduce_sum(tf.to_float(tf.equal(pred_confident, true_confident))) / tf.to_float(n_examples)

        """ info about error examples """
        cond_error = tf.not_equal(pred, true)
        idxes_error = tf.reshape(tf.where(cond_error), [-1])
        pred_error = tf.gather(pred, idxes_error)
        true_error = tf.gather(true, idxes_error)

        o = dict()
        o['accur'] = (accr, accr_limited_in_w_confidence, accr_w_confidence)
        o['n'] = (n_examples, n_rejected)
        o['error'] = (pred_error, true_error)
        o['data'] = (pred, true, prob, is_hit)

        return o

    def _logsoftmax(self, x):
        xdev = x - tf.reduce_max(x, 1, keepdims=True)
        lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keepdims=True))
        return lsm

    def _kl_divergence_with_logit(self, q_logit, p_logit):
        q = tf.nn.softmax(q_logit)
        qlogq = tf.reduce_mean(tf.reduce_sum(q * self._logsoftmax(q_logit), 1))
        qlogp = tf.reduce_mean(tf.reduce_sum(q * self._logsoftmax(p_logit), 1))
        return qlogq - qlogp

    def _js_divergence_with_logit(self, q_logit, p_logit):
        q = tf.nn.softmax(q_logit)
        p = tf.nn.softmax(p_logit)
        m = (q + p) / 2
        qlogq = tf.reduce_mean(tf.reduce_sum(q * self._logsoftmax(q_logit), 1))
        plogp = tf.reduce_mean(tf.reduce_sum(p * self._logsoftmax(p_logit), 1))
        # sum over classes (axis 1) before averaging over the batch, to match
        # the per-example reductions of qlogq/plogp above
        qlogm = tf.reduce_mean(tf.reduce_sum(q * tf.log(m + eps), 1))
        plogm = tf.reduce_mean(tf.reduce_sum(p * tf.log(m + eps), 1))
        return (qlogq + plogp - qlogm - plogm) / 2

    def get_loss_entropy_yx(self, logit):
        p = tf.nn.softmax(logit)
        return -tf.reduce_mean(tf.reduce_sum(p * self._logsoftmax(logit), 1))
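
# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original repository): a minimal check of
# the divergence helpers above on tiny constant logits, assuming only this
# module and a TF 1.x session. KL(q||p) should be >= 0, and JS should be
# symmetric and bounded by log(2) ~= 0.693.
def _demo_divergences():
    with tf.Graph().as_default():
        lf = LossFunctions(layers=None, dataset=None)
        q_logit = tf.constant([[2.0, 0.0, -1.0]])
        p_logit = tf.constant([[1.5, 0.5, -1.0]])
        kl_qp = lf._kl_divergence_with_logit(q_logit, p_logit)
        js_qp = lf._js_divergence_with_logit(q_logit, p_logit)
        js_pq = lf._js_divergence_with_logit(p_logit, q_logit)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            kl, js1, js2 = sess.run([kl_qp, js_qp, js_pq])
            print('KL(q||p)=%.4f  JS(q,p)=%.4f  JS(p,q)=%.4f' % (kl, js1, js2))
# ------------------------------------------------------------------------------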
--------------------------------------------------------------------------------
/vae/util/svhn.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from scipy.io import loadmat

import numpy as np
from scipy import linalg
import glob
import pickle

from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import urllib

import tensorflow as tf
from dataset_utils import *

DATA_URL_TRAIN = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat'
DATA_URL_TEST = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat'

#N_LABELED = 4000
N_LABELED = 73257  # use the whole training set as labeled examples
DATASET_SEED = 1
DATA_DIR = '/data/img/SVHN'  # note: set this directly; no --data_dir flag is defined in this file

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples")

NUM_EXAMPLES_TRAIN = 73257
NUM_EXAMPLES_TEST = 26032


def maybe_download_and_extract():
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)
    filepath_train_mat = os.path.join(DATA_DIR, 'train_32x32.mat')
    filepath_test_mat = os.path.join(DATA_DIR, 'test_32x32.mat')
    print(filepath_train_mat)
    print(filepath_test_mat)
    if not os.path.exists(filepath_train_mat) or not os.path.exists(filepath_test_mat):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        urllib.request.urlretrieve(DATA_URL_TRAIN, filepath_train_mat, _progress)
        urllib.request.urlretrieve(DATA_URL_TEST, filepath_test_mat, _progress)

    # Training set
    print("... loading training data")
    train_data = loadmat(DATA_DIR + '/train_32x32.mat')

    # scale pixel values to [0, 1]; the zero-centered alternative
    # (-127.5 + X) / 255. was disabled by the author
    train_x = (train_data['X']) / 255.
    train_x = train_x.transpose((3, 0, 1, 2))
    train_x = train_x.reshape([train_x.shape[0], -1])
    train_y = train_data['y'].flatten().astype(np.int32)
    # SVHN stores the digit 0 as class 10; remap it to label 0
    train_y[train_y == 10] = 0

    # Test set
    print("... loading testing data")
    test_data = loadmat(DATA_DIR + '/test_32x32.mat')
    test_x = (test_data['X']) / 255.
    test_x = test_x.transpose((3, 0, 1, 2))
    test_x = test_x.reshape((test_x.shape[0], -1))
    test_y = test_data['y'].flatten().astype(np.int32)
    test_y[test_y == 10] = 0

    print("... saving npy as cache")
    np.save('{}/train_images'.format(DATA_DIR), train_x)
    np.save('{}/train_labels'.format(DATA_DIR), train_y)
    np.save('{}/test_images'.format(DATA_DIR), test_x)
    np.save('{}/test_labels'.format(DATA_DIR), test_y)
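
# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a quick sanity check of
# the npy cache written above. The expected shapes assume the standard SVHN
# splits and the flattened 32x32x3 layout used in this module.
def _check_cache_shapes():
    (train_x, train_y), (test_x, test_y) = load_svhn(_use_cache=True)
    assert train_x.shape == (NUM_EXAMPLES_TRAIN, 32 * 32 * 3)
    assert test_x.shape == (NUM_EXAMPLES_TEST, 32 * 32 * 3)
    assert set(np.unique(train_y)) <= set(range(10))  # labels remapped to 0..9
    print('cache OK:', train_x.shape, test_x.shape)
# ------------------------------------------------------------------------------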
def load_svhn(_use_cache=True):

    if _use_cache:
        print("... loading data from npy cache")
    else:
        maybe_download_and_extract()

    # returned shapes:
    #   images: (n, img_size)
    #   labels: (n,)
    train_images = np.load('{}/train_images.npy'.format(DATA_DIR)).astype(np.float32)
    train_labels = np.load('{}/train_labels.npy'.format(DATA_DIR)).astype(np.uint8)
    test_images = np.load('{}/test_images.npy'.format(DATA_DIR)).astype(np.float32)
    test_labels = np.load('{}/test_labels.npy'.format(DATA_DIR)).astype(np.uint8)
    return (train_images, train_labels), (test_images, test_labels)


def prepare_dataset():
    (train_images, train_labels), (test_images, test_labels) = load_svhn()
    dirpath = os.path.join(DATA_DIR, 'seed' + str(DATASET_SEED))
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)

    rng = np.random.RandomState(DATASET_SEED)
    rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN)
    _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix]

    labeled_ind = np.arange(N_LABELED)
    labeled_train_images, labeled_train_labels = _train_images[labeled_ind], _train_labels[labeled_ind]
    _train_images = np.delete(_train_images, labeled_ind, 0)
    _train_labels = np.delete(_train_labels, labeled_ind, 0)
    convert_images_and_labels(labeled_train_images,
                              labeled_train_labels,
                              os.path.join(dirpath, 'labeled_train.tfrecords'))
    # the whole training set (labeled and unlabeled) serves as unlabeled data
    convert_images_and_labels(train_images, train_labels,
                              os.path.join(dirpath, 'unlabeled_train.tfrecords'))
    convert_images_and_labels(test_images,
                              test_labels,
                              os.path.join(dirpath, 'test.tfrecords'))

    # Construct dataset for validation
    train_images_valid, train_labels_valid = labeled_train_images, labeled_train_labels
    test_images_valid, test_labels_valid = \
        _train_images[:FLAGS.num_valid_examples], _train_labels[:FLAGS.num_valid_examples]
    unlabeled_train_images_valid = np.concatenate(
        (train_images_valid, _train_images[FLAGS.num_valid_examples:]), axis=0)
    unlabeled_train_labels_valid = np.concatenate(
        (train_labels_valid, _train_labels[FLAGS.num_valid_examples:]), axis=0)
    convert_images_and_labels(train_images_valid,
                              train_labels_valid,
                              os.path.join(dirpath, 'labeled_train_val.tfrecords'))
    convert_images_and_labels(unlabeled_train_images_valid,
                              unlabeled_train_labels_valid,
                              os.path.join(dirpath, 'unlabeled_train_val.tfrecords'))
    convert_images_and_labels(test_images_valid,
                              test_labels_valid,
                              os.path.join(dirpath, 'test_val.tfrecords'))


def inputs(batch_size=100,
           train=True, validation=False,
           shuffle=True, num_epochs=None):
    if validation:
        if train:
            filenames = ['labeled_train_val.tfrecords']
            num_examples = N_LABELED
        else:
            filenames = ['test_val.tfrecords']
            num_examples = FLAGS.num_valid_examples
    else:
        if train:
            filenames = ['labeled_train.tfrecords']
            num_examples = N_LABELED
        else:
            filenames = ['test.tfrecords']
            num_examples = NUM_EXAMPLES_TEST

    filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames]
    filename_queue = generate_filename_queue(filenames, DATA_DIR, num_epochs)
    image, label = read(filename_queue)
    image = transform(tf.cast(image, tf.float32)) if train else image
    return generate_batch([image, label], num_examples, batch_size, shuffle)
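
# ------------------------------------------------------------------------------
# Illustrative sketch (not from the original repository): how the queue-based
# pipeline above is typically consumed under TF 1.x. The batch shapes depend on
# read()/transform() in dataset_utils; here the images are flattened vectors.
def _demo_input_pipeline():
    images, labels = inputs(batch_size=8, train=True, validation=False)
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        x, y = sess.run([images, labels])
        print(x.shape, y.shape)
        coord.request_stop()
        coord.join(threads)
# ------------------------------------------------------------------------------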
def unlabeled_inputs(batch_size=100,
                     validation=False,
                     shuffle=True):
    if validation:
        filenames = ['unlabeled_train_val.tfrecords']
        num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples
    else:
        filenames = ['unlabeled_train.tfrecords']
        num_examples = NUM_EXAMPLES_TRAIN

    filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames]
    filename_queue = generate_filename_queue(filenames, data_dir=DATA_DIR)
    image, label = read(filename_queue)
    image = transform(tf.cast(image, tf.float32))
    return generate_batch([image], num_examples, batch_size, shuffle)


def main(argv):
    prepare_dataset()


if __name__ == "__main__":
    tf.app.run()
--------------------------------------------------------------------------------