├── README.md
├── define_path.sh
├── images
│   ├── car_15_small.gif
│   ├── car_sample.stl
│   ├── furniture-render.gif
│   └── morph.gif
├── main.py
├── model_pix2pix.py
├── model_sdfgan.py
├── ops.py
├── stream_freqsplit.py
├── test_all.sh
├── train_pix2pix.sh
├── train_sdfgan.sh
└── utils.py


/README.md:
--------------------------------------------------------------------------------
1 | # 3D-SDFGAN
2 | 3D Signed Distance Function Based Generative Adversarial Networks
3 | 
4 | 
5 | 
6 | ![3D Model of Above Car Sample](/images/car_sample.stl)
7 | 
8 | 
9 | 
10 | 
11 | 
12 | ## About this study
13 | This study seeks to generate realistic-looking, mesh-based 3D models using GANs. Training is based on the [ShapeNetCore dataset](https://www.shapenet.org/), which has been post-processed into 64x64x64 signed distance function fields. More details about this study can be found in [this paper](https://arxiv.org/abs/1709.07581).
14 | 
15 | ## Reproducing the results
16 | ### Collecting the dataset
17 | First, define the project directory that stores the data and computed results (checkpoints, logs, samples). Change the project directory, job name and dataset to train on in the file `define_path.sh`. These parameters are used globally throughout training and testing:
18 | ```bash
19 | # variables to define by user
20 | PROJ_DIR=/Path/To/Your/Project/Directory
21 | JOB_NAME=job0-yourjobname
22 | DATASET=synset_02958343_car
23 | ```
24 | Source the script to export the variables into the current shell (executing it as `./define_path.sh` would only set them in a subshell):
25 | ```
26 | sudo chmod +x *.sh
27 | source ./define_path.sh
28 | ```
29 | Go to the project directory and download the data (~28.8G; this might take quite a while):
30 | ```
31 | cd $PROJ_DIR
32 | wget http://island.me.berkeley.edu/data/data.tar.gz
33 | tar -xzvf data.tar.gz
34 | ```
35 | Go back to the code directory and run the frequency-splitting code on a dataset (e.g. the car dataset) in preparation for running pix2pix:
36 | ```
37 | cd /Path/To/SDFGAN
38 | python stream_freqsplit.py ${PROJ_DIR}/data/synset_02958343_car ${PROJ_DIR}/data/synset_02958343_car_freqsplit 0 0
39 | ```
40 | The above command runs the freqsplit algorithm on multiple threads. The two trailing zeros default the code to using all available threads and converting all data within the dataset.
41 | 
42 | ### Running the sdfgan part of the code
43 | ```
44 | ./train_sdfgan.sh
45 | ```
46 | ### Running the pix2pix part of the code
47 | ```
48 | ./train_pix2pix.sh
49 | ```
50 | ### Generate test results
51 | ```
52 | ./test_all.sh
53 | ```
54 | ### Post processing for mesh
55 | Code for post-processing the results into meshes is not included in this repository; however, this can easily be achieved with a marching cubes algorithm (see the sketch at the end of this README). The mesh in the example above is further post-processed using 3 steps of Laplacian smoothing, quadric mesh decimation by half, and hole-filling for the holes at the front and back that result from boundary effects caused by the limitations of the bounding box.
56 | 
57 | ## Open Source Code Credits
58 | Borrows code from the following repositories:
59 | * [DCGAN-tensorflow](http://carpedm20.github.io/) repository by Taehoon Kim
60 | * [Pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow) repository by affinelayer
61 | 
62 | Relied on the following open source projects for pre- and post-processing of 3D meshes / signed distance functions:
63 | * [Libigl](https://github.com/libigl/libigl) interactive graphics library by Alec Jacobson et al.
64 | * [Meshlab](http://www.meshlab.net/) for mesh rendering and minor post-processing. 
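
The following is a minimal, illustrative sketch of the marching cubes step, assuming `scikit-image` is installed and that `test_all.sh` has already produced `full_final_samples.npy` in the samples directory (shape `(3, sample_num, 64, 64, 64)`; see `main.py`). The sample index and output filename are arbitrary; smoothing, decimation and hole-filling are left to an external tool such as Meshlab:
```python
import numpy as np
from skimage import measure  # on older scikit-image versions, use measure.marching_cubes_lewiner

samples = np.load("samples/full_final_samples.npy")
sdf = samples[0, 0]  # combined-frequency field of the first sample

# the object surface is the zero level set of the signed distance function
verts, faces, normals, values = measure.marching_cubes(sdf, level=0.0)

# write a minimal Wavefront OBJ file for inspection and post-processing in Meshlab
with open("car_sample_raw.obj", "w") as f:
    for v in verts:
        f.write("v {:.6f} {:.6f} {:.6f}\n".format(*v))
    for tri in faces:
        f.write("f {} {} {}\n".format(tri[0] + 1, tri[1] + 1, tri[2] + 1))  # OBJ indices are 1-based
```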
65 | 
--------------------------------------------------------------------------------
/define_path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # variables to define by user
4 | PROJ_DIR=/home/maxjiang/Codes/project_sdfgan/sdfgan_data
5 | JOB_NAME=job6-plane
6 | DATASET=synset_02958343_car
7 | 
8 | # export variables
9 | export PROJ_DIR
10 | export JOB_NAME
11 | export DATASET
--------------------------------------------------------------------------------
/images/car_15_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/car_15_small.gif
--------------------------------------------------------------------------------
/images/car_sample.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/car_sample.stl
--------------------------------------------------------------------------------
/images/furniture-render.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/furniture-render.gif
--------------------------------------------------------------------------------
/images/morph.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/morph.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from model_sdfgan import SDFGAN
2 | from model_pix2pix import Pix2Pix
3 | from utils import *
4 | from ops import batch_lowpass, batch_mirr
5 | import shutil
6 | 
7 | import tensorflow as tf
8 | 
9 | # common training flags
10 | flags = tf.app.flags
11 | flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
12 | flags.DEFINE_string("model", "sdfgan", "Model to train. Choice of 'sdfgan' or 'pix2pix' [sdfgan]")
13 | flags.DEFINE_boolean("is_new", False, "True for training from scratch, deleting original files [False]")
14 | flags.DEFINE_integer("epoch", 100, "Epoch to train [100]")
15 | flags.DEFINE_float("d_learning_rate", 0.0002, "Learning rate of discrim for adam [0.0002]")
16 | flags.DEFINE_float("g_learning_rate", 0.0005, "Learning rate of gen for adam [0.0005]")
17 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
18 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
19 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
20 | flags.DEFINE_integer("image_depth", 64, "The size of sdf field to use. [64]")
21 | flags.DEFINE_integer("image_height", None, "The size of sdf to use. If None, same value as image_depth [None]")
22 | flags.DEFINE_integer("image_width", None, "The size of sdf to use. If None, same value as image_depth [None]")
23 | flags.DEFINE_integer("c_dim", 1, "Number of channels. 
[1]") 24 | flags.DEFINE_string("dataset", "shapenet", "The name of dataset [shapenet]") 25 | flags.DEFINE_string("input_fname_pattern", "*.npy", "Glob pattern of filename of input sdf [*]") 26 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") 27 | flags.DEFINE_string("dataset_dir", "data", "Directory name to read the input training data [data]") 28 | flags.DEFINE_string("log_dir", "logs", "Directory name to save the log files [logs]") 29 | flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") 30 | flags.DEFINE_integer("num_gpus", 1, "Number of GPUs to use [1]") 31 | flags.DEFINE_integer("random_seed", 1, "Random seed [1]") 32 | flags.DEFINE_integer("save_interval", 200, "Interval in steps for saving model and samples [200]") 33 | 34 | # pix2pix specific training flags 35 | flags.DEFINE_integer("gan_weight", 1, "GAN weight in generator loss function. [1]") 36 | flags.DEFINE_integer("l1_weight", 100, "L1 weight in generator loss function. [100]") 37 | 38 | # testing flags 39 | flags.DEFINE_integer("sample_num", 16, "Number of samples. [16]") 40 | flags.DEFINE_string("test_from_input_path", None, "Path to test inputs. [None]") 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | def main(_): 45 | pp.pprint(flags.FLAGS.__flags) 46 | 47 | if FLAGS.image_height is None: 48 | FLAGS.image_height = FLAGS.image_depth 49 | if FLAGS.image_width is None: 50 | FLAGS.image_width = FLAGS.image_depth 51 | 52 | if FLAGS.is_new: 53 | if os.path.exists(FLAGS.checkpoint_dir): 54 | shutil.rmtree(FLAGS.checkpoint_dir) 55 | if os.path.exists(FLAGS.sample_dir): 56 | shutil.rmtree(FLAGS.sample_dir) 57 | if os.path.exists(FLAGS.log_dir): 58 | shutil.rmtree(FLAGS.log_dir) 59 | 60 | if not os.path.exists(FLAGS.checkpoint_dir): 61 | os.makedirs(FLAGS.checkpoint_dir) 62 | if not os.path.exists(FLAGS.sample_dir): 63 | os.makedirs(FLAGS.sample_dir) 64 | 65 | # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) 66 | run_config = tf.ConfigProto() 67 | run_config.gpu_options.allow_growth = True 68 | 69 | # set random seed 70 | tf.set_random_seed(FLAGS.random_seed) 71 | np.random.seed(FLAGS.random_seed) 72 | 73 | init_sdfgan = lambda sess: SDFGAN( 74 | sess, 75 | image_depth=FLAGS.image_depth, 76 | image_height=FLAGS.image_height, 77 | image_width=FLAGS.image_width, 78 | batch_size=FLAGS.batch_size, 79 | sample_num=FLAGS.batch_size, 80 | c_dim=FLAGS.c_dim, 81 | dataset_name=FLAGS.dataset, 82 | input_fname_pattern=FLAGS.input_fname_pattern, 83 | checkpoint_dir=FLAGS.checkpoint_dir, 84 | dataset_dir=FLAGS.dataset_dir, 85 | log_dir=FLAGS.log_dir, 86 | sample_dir=FLAGS.sample_dir, 87 | num_gpus=FLAGS.num_gpus, 88 | save_interval=FLAGS.save_interval) 89 | 90 | init_pix2pix = lambda sess: Pix2Pix( 91 | sess, 92 | image_depth=FLAGS.image_depth, 93 | image_height=FLAGS.image_height, 94 | image_width=FLAGS.image_width, 95 | batch_size=FLAGS.batch_size, 96 | sample_num=FLAGS.sample_num, 97 | gan_weight=FLAGS.gan_weight, 98 | l1_weight=FLAGS.l1_weight, 99 | c_dim=FLAGS.c_dim, 100 | dataset_name=FLAGS.dataset, 101 | input_fname_pattern=FLAGS.input_fname_pattern, 102 | checkpoint_dir=FLAGS.checkpoint_dir, 103 | dataset_dir=FLAGS.dataset_dir, 104 | log_dir=FLAGS.log_dir, 105 | sample_dir=FLAGS.sample_dir, 106 | num_gpus=FLAGS.num_gpus, 107 | save_interval=FLAGS.save_interval) 108 | 109 | # Train SDFGAN 110 | if FLAGS.model == 'sdfgan' and FLAGS.is_train: 111 | with tf.Session(config=run_config) as sess: 112 | sdfgan = 
init_sdfgan(sess)
113 |             show_all_variables()
114 |             if not os.path.exists(os.path.join(FLAGS.sample_dir, "sdfgan_sample")):
115 |                 os.makedirs(os.path.join(FLAGS.sample_dir, "sdfgan_sample"))
116 |             if not os.path.exists(os.path.join(FLAGS.log_dir, "sdfgan_log")):  # check log_dir, not sample_dir
117 |                 os.makedirs(os.path.join(FLAGS.log_dir, "sdfgan_log"))
118 |             sdfgan.train(FLAGS)
119 | 
120 |     # Train Pix2Pix
121 |     elif FLAGS.model == 'pix2pix' and FLAGS.is_train:
122 |         with tf.Session(config=run_config) as sess:
123 |             pix2pix = init_pix2pix(sess)
124 |             show_all_variables()
125 |             if not os.path.exists(os.path.join(FLAGS.sample_dir, "pix2pix_sample")):
126 |                 os.makedirs(os.path.join(FLAGS.sample_dir, "pix2pix_sample"))
127 |             if not os.path.exists(os.path.join(FLAGS.log_dir, "pix2pix_log")):
128 |                 os.makedirs(os.path.join(FLAGS.log_dir, "pix2pix_log"))
129 |             pix2pix.train(FLAGS)
130 | 
131 |     # Test entire network
132 |     elif not FLAGS.is_train:
133 |         # load sdfgan network to create new samples if necessary
134 |         if FLAGS.test_from_input_path is None:
135 |             with tf.Session(config=run_config) as sess_0:
136 |                 sdfgan = init_sdfgan(sess_0)
137 |                 show_all_variables()
138 |                 # generate a sample from sdfgan first
139 |                 if not sdfgan.load(FLAGS.checkpoint_dir)[0]:  # load() returns (success, counter)
140 |                     raise Exception("[!] Could not load SDFGAN model. Train first, then run test mode.")
141 |                 FLAGS.test_from_input_path = create_sdfgan_samples(sess_0, sdfgan, FLAGS)
142 |                 print("Saving intermediate unprocessed samples to {0}".format(FLAGS.test_from_input_path))
143 |                 sess_0.close()
144 | 
145 |         # post-process samples (low-pass filter & mirroring)
146 |         sdf = np.squeeze(np.load(FLAGS.test_from_input_path), axis=-1)
147 |         sdf_lf = batch_mirr(batch_lowpass(sdf))
148 | 
149 |         # create and save final samples
150 |         tf.reset_default_graph()
151 |         with tf.Session(config=run_config) as sess_1:
152 |             pix2pix = init_pix2pix(sess_1)
153 |             show_all_variables()
154 |             if not pix2pix.load(FLAGS.checkpoint_dir)[0]:  # load() returns (success, counter)
155 |                 raise Exception("[!] Could not load Pix2Pix model. Train first, then run test mode.")
156 |             sdf_hf = np.squeeze(create_pix2pix_samples(sess_1, pix2pix, sdf_lf), axis=-1)
157 |             sdf_all = sdf_lf + sdf_hf
158 |             sdf_save = np.concatenate((np.expand_dims(sdf_all, axis=0),
159 |                                        np.expand_dims(sdf_lf, axis=0),
160 |                                        np.expand_dims(sdf_hf, axis=0)), axis=0)
161 |             fname = os.path.join(FLAGS.sample_dir, "full_final_samples.npy")
162 |             np.save(fname, sdf_save)
163 |             print("Saving final full samples to {0}, "
164 |                   "shape: {1} (combine_freq/low_freq/hi_freq, sample_num, dim0, dim1, dim2)"
165 |                   .format(fname, sdf_save.shape))
166 |             sess_1.close()
167 | 
168 |     else:
169 |         raise Exception("[!] 
Model must be 'sdfgan' or 'pix2pix'.") 170 | 171 | 172 | if __name__ == '__main__': 173 | tf.app.run() 174 | -------------------------------------------------------------------------------- /model_pix2pix.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import time 3 | from glob import glob 4 | import math 5 | 6 | from ops import * 7 | from utils import * 8 | 9 | EPS = 1e-12 10 | 11 | 12 | class Pix2Pix(object): 13 | def __init__(self, sess, image_depth=64, image_height=64, image_width=64, 14 | batch_size=64, sample_num=64, gf_dim=64, df_dim=64, gan_weight=1, l1_weight=100, 15 | c_dim=1, dataset_name='shapenet_freqsplit', num_gpus=1, save_interval=200, 16 | input_fname_pattern='*.npy', checkpoint_dir=None, dataset_dir=None, log_dir=None, sample_dir=None): 17 | """ 18 | 19 | Args: 20 | sess: TensorFlow session 21 | batch_size: The size of batch. Should be specified before training. 22 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 23 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64] 24 | c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3] 25 | """ 26 | self.sess = sess 27 | 28 | self.batch_size = batch_size 29 | self.sample_num = sample_num 30 | 31 | self.image_depth = image_depth 32 | self.image_height = image_height 33 | self.image_width = image_width 34 | 35 | self.gf_dim = gf_dim 36 | self.df_dim = df_dim 37 | 38 | self.c_dim = c_dim 39 | self.save_inverval = save_interval 40 | 41 | self.gan_weight = gan_weight 42 | self.l1_weight = l1_weight 43 | 44 | self.num_gpus = num_gpus 45 | self.glob_batch_size = self.num_gpus * self.batch_size 46 | 47 | self.dataset_name = dataset_name 48 | self.input_fname_pattern = input_fname_pattern 49 | self.checkpoint_dir = checkpoint_dir 50 | self.dataset_dir = dataset_dir 51 | self.log_dir = os.path.join(log_dir, "pix2pix_log") 52 | self.sample_dir = os.path.join(sample_dir, "pix2pix_sample") 53 | self.build_model() 54 | 55 | def build_model(self): 56 | 57 | image_dims = [self.image_depth, self.image_height, self.image_width, self.c_dim] 58 | 59 | # input placeholders 60 | self.inputs = tf.placeholder( 61 | tf.float32, [self.glob_batch_size] + image_dims, name='input_images') 62 | self.targets = tf.placeholder( 63 | tf.float32, [self.glob_batch_size] + image_dims, name='target_images') 64 | self.sample_inputs = tf.placeholder( 65 | tf.float32, [self.sample_num] + image_dims, name='sample_inputs') 66 | 67 | self.n_eff = tf.placeholder(tf.int32, name='n_eff') # overall number of effective data points 68 | 69 | # initialize global lists 70 | self.G = [None] * self.num_gpus 71 | self.D = [None] * self.num_gpus 72 | self.D_ = [None] * self.num_gpus 73 | self.d_losses = [None] * self.num_gpus 74 | self.g_losses = [None] * self.num_gpus 75 | self.g_losses_gan = [None] * self.num_gpus 76 | self.g_losses_l1 = [None] * self.num_gpus 77 | self.n_effs = [None] * self.num_gpus 78 | 79 | # compute using multiple gpus 80 | with tf.variable_scope(tf.get_variable_scope()) as vscope: 81 | for gpuid in xrange(self.num_gpus): 82 | with tf.device('/gpu:%d' % gpuid): 83 | 84 | # range of data for this gpu 85 | gpu_start = gpuid * self.batch_size 86 | gpu_end = (gpuid + 1) * self.batch_size 87 | 88 | # number of effective data points 89 | gpu_n_eff = tf.reduce_min([tf.reduce_max([0, self.n_eff - gpu_start]), self.batch_size]) 90 | 91 | # create examples and pass through discriminator 92 | gpu_G = 
self.generator(self.inputs[gpu_start:gpu_end])
93 |                     gpu_D = self.discriminator(self.inputs[gpu_start:gpu_end], self.targets[gpu_start:gpu_end])  # real
94 |                     gpu_D_ = self.discriminator(self.inputs[gpu_start:gpu_end], gpu_G, reuse=True)  # fake pairs
95 | 
96 |                     # compute discriminator loss
97 |                     gpu_d_loss = tf.reduce_mean(-(tf.log(gpu_D[:gpu_n_eff] + EPS)
98 |                                                   + tf.log(1 - gpu_D_[:gpu_n_eff] + EPS)))
99 | 
100 |                     # compute generator loss
101 |                     gpu_g_loss_gan = tf.reduce_mean(-tf.log(gpu_D_[:gpu_n_eff] + EPS))  # fake-pair score, so gradients reach G
102 |                     gpu_g_loss_l1 = tf.reduce_mean(tf.abs(self.targets[gpu_start:gpu_end] - gpu_G))
103 |                     gpu_g_loss = gpu_g_loss_gan * self.gan_weight + gpu_g_loss_l1 * self.l1_weight
104 | 
105 |                     # add gpu-wise data to global list
106 |                     self.G[gpuid] = gpu_G
107 |                     self.D[gpuid] = gpu_D
108 |                     self.D_[gpuid] = gpu_D_
109 |                     self.d_losses[gpuid] = gpu_d_loss
110 |                     self.g_losses[gpuid] = gpu_g_loss
111 |                     self.g_losses_gan[gpuid] = gpu_g_loss_gan
112 |                     self.g_losses_l1[gpuid] = gpu_g_loss_l1
113 |                     self.n_effs[gpuid] = gpu_n_eff
114 | 
115 |                     # Reuse variables for the next gpu
116 |                     tf.get_variable_scope().reuse_variables()
117 | 
118 |         # concatenate across GPUs
119 |         self.D = tf.concat(self.D, axis=0)
120 |         self.D_ = tf.concat(self.D_, axis=0)
121 |         self.G = tf.concat(self.G, axis=0)
122 |         weighted_d_loss = [self.d_losses[j] * tf.cast(self.n_effs[j], tf.float32)
123 |                            / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
124 |         weighted_g_loss = [self.g_losses[j] * tf.cast(self.n_effs[j], tf.float32)
125 |                            / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
126 |         weighted_g_loss_gan = [tf.cast(self.g_losses_gan[j], tf.float32) * tf.cast(self.n_effs[j], tf.float32)
127 |                                / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
128 |         weighted_g_loss_l1 = [tf.cast(self.g_losses_l1[j], tf.float32) * tf.cast(self.n_effs[j], tf.float32)
129 |                               / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
130 | 
131 |         self.d_loss = tf.reduce_sum(weighted_d_loss, axis=0)
132 |         self.g_loss = tf.reduce_sum(weighted_g_loss, axis=0)
133 |         self.g_loss_gan = tf.reduce_sum(weighted_g_loss_gan, axis=0)
134 |         self.g_loss_l1 = tf.reduce_sum(weighted_g_loss_l1, axis=0)
135 | 
136 |         # summarize variables
137 |         self.d_sum = histogram_summary("d", self.D)
138 |         self.d__sum = histogram_summary("d_", self.D_)
139 |         self.g_sum = image_summary("G", self.G[:, int(self.image_depth / 2), :, :])
140 | 
141 |         self.d_loss_sum = scalar_summary("d_loss", self.d_loss)
142 | 
143 |         self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
144 |         self.g_loss_gan_sum = scalar_summary("g_loss_gan", self.g_loss_gan)
145 |         self.g_loss_l1_sum = scalar_summary("g_loss_l1", self.g_loss_l1)
146 | 
147 |         # define trainable variables for generator and discriminator
148 |         t_vars = tf.trainable_variables()
149 |         self.d_vars = [var for var in t_vars if 'd_' in var.name]
150 |         self.g_vars = [var for var in t_vars if 'g_' in var.name]
151 | 
152 |         self.sampler = self.sampler(self.sample_inputs)
153 |         self.saver = tf.train.Saver()
154 | 
155 |     def train(self, config):
156 |         """Train Pix2Pix"""
157 |         data = glob(os.path.join(self.dataset_dir, config.dataset, self.input_fname_pattern))
158 |         np.random.shuffle(data)
159 | 
160 |         # define optimization operation
161 |         d_opt = tf.train.AdamOptimizer(config.d_learning_rate, beta1=config.beta1)
162 |         g_opt = tf.train.AdamOptimizer(config.g_learning_rate, beta1=config.beta1)
163 | 
164 |         # create list of grads from different gpus
165 |         global_d_grads_vars = [None] * self.num_gpus
166 | 
global_g_grads_vars = [None] * self.num_gpus 167 | 168 | # compute d gradients 169 | with tf.variable_scope(tf.get_variable_scope()): 170 | for gpuid in xrange(self.num_gpus): 171 | with tf.device('/gpu:%d' % gpuid): 172 | gpu_d_grads_vars = d_opt.compute_gradients(loss=self.d_losses[gpuid], var_list=self.d_vars) 173 | global_d_grads_vars[gpuid] = gpu_d_grads_vars 174 | 175 | # compute g gradients 176 | with tf.variable_scope(tf.get_variable_scope()): 177 | for gpuid in xrange(self.num_gpus): 178 | with tf.device('/gpu:%d' % gpuid): 179 | gpu_g_grads_vars = g_opt.compute_gradients(loss=self.g_losses[gpuid], var_list=self.g_vars) 180 | global_g_grads_vars[gpuid] = gpu_g_grads_vars 181 | 182 | # average gradients across gpus and apply gradients 183 | d_grads_vars = average_gradients(global_d_grads_vars) 184 | g_grads_vars = average_gradients(global_g_grads_vars) 185 | d_optim = d_opt.apply_gradients(d_grads_vars) 186 | g_optim = g_opt.apply_gradients(g_grads_vars) 187 | 188 | # compatibility across tf versions 189 | try: 190 | tf.global_variables_initializer().run() 191 | except: 192 | tf.initialize_all_variables().run() 193 | 194 | self.g_sum = merge_summary([self.d__sum, self.g_loss_gan_sum, self.g_loss_l1_sum, self.g_loss_sum, self.g_sum]) 195 | self.d_sum = merge_summary([self.d_sum, self.d_loss_sum]) 196 | 197 | self.writer = SummaryWriter(self.log_dir, self.sess.graph) 198 | sample_files = data[0:self.sample_num] 199 | 200 | sample_inputs = [np.load(sample_file)[0, :, :, :] for sample_file in sample_files] 201 | sample_targets = [np.load(sample_file)[1, :, :, :] for sample_file in sample_files] 202 | sample_in = np.array(sample_inputs).astype(np.float32)[:, :, :, :, None] 203 | sample_tg = np.array(sample_targets).astype(np.float32)[:, :, :, :, None] 204 | 205 | counter = 0 206 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 207 | if could_load: 208 | counter = checkpoint_counter 209 | print(" [*] Load SUCCESS") 210 | else: 211 | print(" [!] 
Load failed...") 212 | 213 | batch_idxs = int(math.ceil(min(len(data), config.train_size) / self.glob_batch_size)) - 1 214 | total_steps = config.epoch * batch_idxs 215 | prev_time = -np.inf 216 | 217 | for epoch in xrange(config.epoch): 218 | # shuffle data before training in each epoch 219 | np.random.shuffle(data) 220 | for idx in xrange(0, batch_idxs): 221 | glob_batch_files = data[idx * self.glob_batch_size:(idx + 1) * self.glob_batch_size] 222 | glob_batch_inputs = [ 223 | np.load(batch_file)[0, :, :, :] for batch_file in glob_batch_files] 224 | glob_batch_targets = [ 225 | np.load(batch_file)[1, :, :, :] for batch_file in glob_batch_files] 226 | glob_batch_in = np.array(glob_batch_inputs).astype(np.float32)[:, :, :, :, None] 227 | glob_batch_tg = np.array(glob_batch_targets).astype(np.float32)[:, :, :, :, None] 228 | 229 | n_eff = len(glob_batch_files) 230 | 231 | # Pad zeros if effective batch size is smaller than global batch size 232 | if n_eff != self.glob_batch_size: 233 | glob_batch_in = pad_glob_batch(glob_batch_in, self.glob_batch_size) 234 | glob_batch_tg = pad_glob_batch(glob_batch_tg, self.glob_batch_size) 235 | 236 | # Update D network 237 | _, summary_str = self.sess.run([d_optim, self.d_sum], 238 | feed_dict={self.inputs: glob_batch_in, 239 | self.targets: glob_batch_tg, 240 | self.n_eff: n_eff}) 241 | self.writer.add_summary(summary_str, counter) 242 | 243 | # Update G network 244 | _, summary_str = self.sess.run([g_optim, self.g_sum], 245 | feed_dict={self.inputs: glob_batch_in, 246 | self.targets: glob_batch_tg, 247 | self.n_eff: n_eff}) 248 | self.writer.add_summary(summary_str, counter) 249 | 250 | # Compute last batch accuracy and losses 251 | lossD, lossG = self.sess.run([self.d_loss, self.g_loss], 252 | feed_dict={self.inputs: glob_batch_in, 253 | self.targets: glob_batch_tg, 254 | self.n_eff: n_eff}) 255 | 256 | # get time 257 | now_time = time.time() 258 | time_per_iter = now_time - prev_time 259 | prev_time = now_time 260 | eta = (total_steps - counter + checkpoint_counter) * time_per_iter 261 | counter += 1 262 | 263 | try: 264 | timestr = time.strftime("%H:%M:%S", time.gmtime(eta)) 265 | except: 266 | timestr = '?:?:?' 267 | 268 | print("Epoch:[%3d] [%3d/%3d] Iter:[%5d] eta(h:m:s): %s, d_loss: %.8f, g_loss: %.8f" 269 | % (epoch, idx, batch_idxs, counter, timestr, lossD, lossG)) 270 | 271 | # save checkpoint and samples every save_interval steps 272 | if np.mod(counter, self.save_inverval) == 1: 273 | sample_gen = self.sess.run(self.sampler, feed_dict={self.sample_inputs: sample_in}) 274 | sample = np.concatenate((np.expand_dims(sample_in, axis=0), # sample_num x 64 x 64 x 64 x 1 275 | np.expand_dims(sample_tg, axis=0), # sample_num x 64 x 64 x 64 x 1 276 | np.expand_dims(sample_gen, axis=0)), axis=0) # sample_num x 64 x 64 x 64 x 1 277 | np.save(self.sample_dir+'/sample_{:05d}.npy' 278 | .format(counter), sample) 279 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint." 
280 | .format(counter, self.sample_num)) 281 | self.save(config.checkpoint_dir, counter) 282 | 283 | # save last checkpoint 284 | sample_gen = self.sess.run(self.sampler, feed_dict={self.sample_inputs: sample_in}) 285 | sample = np.concatenate((np.expand_dims(sample_in[:, :, :, :, 0], axis=0), # sample_num x 64 x 64 x 64 286 | np.expand_dims(sample_tg[:, :, :, :, 0], axis=0), # sample_num x 64 x 64 x 64 287 | np.expand_dims(sample_gen[:, :, :, :, 0], axis=0)), axis=0) # sample_num x 64 x 64 x 64 288 | np.save(self.sample_dir+'/sample_{:05d}.npy' 289 | .format(counter), sample) 290 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint." 291 | .format(counter, self.sample_num)) 292 | self.save(config.checkpoint_dir, counter) 293 | 294 | print("[!] Training of Pix2Pix Network Complete.") 295 | 296 | def discriminator(self, discrim_inputs, discrim_targets, reuse=False): 297 | with tf.variable_scope("pix2pix_discriminator") as scope: 298 | if reuse: 299 | scope.reuse_variables() 300 | n_layers = 3 301 | layers = [] 302 | 303 | # 2x [batch, depth, height, width, in_channels] => [batch, depth, height, width, in_channels * 2] 304 | input = tf.concat([discrim_inputs, discrim_targets], axis=-1) 305 | 306 | # layer_1: [batch, 64, 64, 64, in_channels * 2] => [batch, 32, 32, 32, df_dim] 307 | convolved = conv3d(input, self.df_dim, name='d_h0_conv') 308 | rectified = lrelu(convolved, 0.2) 309 | layers.append(rectified) 310 | 311 | # layer_2: [batch, 32, 32, 32, df_dim] => [batch, 16, 16, 16, df_dim * 2] 312 | # layer_3: [batch, 16, 16, 16, df_dim * 2] => [batch, 8, 8, 8, df_dim * 4] 313 | # layer_4: [batch, 8, 8, 8, df_dim * 4] => [batch, 7, 7, 7, df_dim * 8] 314 | for i in range(n_layers): 315 | # with tf.variable_scope("layer_%d" % (len(layers) + 1)): 316 | out_channels = self.df_dim * min(2 ** (i + 1), 8) 317 | stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 318 | convolved = conv3d(layers[-1], out_channels, 319 | d_d=stride, d_h=stride, d_w=stride, name='d_h%d_conv' % (len(layers))) 320 | normalized = batchnorm(convolved, name='d_h%d_bn' % (len(layers))) 321 | rectified = lrelu(normalized, 0.2) 322 | layers.append(rectified) 323 | 324 | # layer_5: [batch, 7, 7, df_dim * 8] => [batch, 6, 6, 1] 325 | # with tf.variable_scope("layer_%d" % (len(layers) + 1)): 326 | convolved = conv3d(rectified, output_dim=1, d_d=1, d_h=1, d_w=1, name='d_h%d_conv' % (len(layers))) 327 | output = tf.sigmoid(convolved) 328 | layers.append(output) 329 | 330 | return layers[-1] 331 | 332 | def generator(self, generator_inputs): 333 | with tf.variable_scope("pix2pix_generator") as scope: 334 | layers = [] 335 | 336 | # encoder_1: [batch, 64, 64, 64, in_channels] => [batch, 32, 32, gf_dim] 337 | output = conv3d(generator_inputs, self.gf_dim, name="g_enc_conv_0") 338 | layers.append(output) 339 | 340 | layer_specs = [ 341 | self.gf_dim * 2, # encoder_2: [batch, 32, 32, 32, gf_dim] => [batch, 16, 16, 16, gf_dim * 2] 342 | self.gf_dim * 4, # encoder_3: [batch, 16, 16, 16, gf_dim * 2] => [batch, 8, 8, 8, gf_dim * 4] 343 | self.gf_dim * 8, # encoder_4: [batch, 8, 8, 8, gf_dim * 2] => [batch, 4, 4, 4, gf_dim * 4] 344 | self.gf_dim * 8, # encoder_5: [batch, 4, 4, 4, gf_dim * 8] => [batch, 2, 2, 2, gf_dim * 8] 345 | self.gf_dim * 8, # encoder_6: [batch, 2, 2, 2, gf_dim * 8] => [batch, 1, 1, 1, gf_dim * 8] 346 | ] 347 | 348 | for out_channels in layer_specs: 349 | rectified = lrelu(layers[-1], 0.2) 350 | # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, 
out_channels] 351 | convolved = conv3d(rectified, out_channels, name="g_enc_cov_%d" % (len(layers))) 352 | output = batchnorm(convolved, name="g_enc_bn_%d" % (len(layers))) 353 | layers.append(output) 354 | 355 | layer_specs = [ 356 | (self.gf_dim * 8, 0.5), # decoder_6: [batch,1,1,1,gf_dim * 8] => [batch,2,2,2,gf_dim * 8 * 2] 357 | (self.gf_dim * 8, 0.5), # decoder_5: [batch,2,2,2,gf_dim * 8 * 2] => [batch,4 4,4,gf_dim * 8 * 2] 358 | (self.gf_dim * 4, 0.5), # decoder_4: [batch,4,4,4,gf_dim * 8 * 2] => [batch,8,8,8,gf_dim * 4 * 2] 359 | (self.gf_dim * 2, 0.0), # decoder_3: [batch,8,8,8,gf_dim * 4 * 2] => [batch,16,16,16,gf_dim * 2 * 2] 360 | (self.gf_dim * 1, 0.0), # decoder_2: [batch,16,16,16,gf_dim * 2 * 2] => [batch,32,32,32,gf_dim * 2] 361 | ] 362 | 363 | num_encoder_layers = len(layers) 364 | for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): 365 | skip_layer = num_encoder_layers - decoder_layer - 1 366 | if decoder_layer == 0: 367 | # first decoder layer doesn't have skip connections 368 | # since it is directly connected to the skip_layer 369 | input = layers[-1] 370 | else: 371 | input = tf.concat([layers[-1], layers[skip_layer]], axis=-1) 372 | 373 | rectified = tf.nn.relu(input) 374 | # [batch, in_depth, in_height, in_width, in_channels] 375 | # => [batch, in_depth, in_height*2, in_width*2, out_channels] 376 | in_dims = rectified.get_shape().as_list() 377 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, out_channels] 378 | output = deconv3d(rectified, out_dims, name="g_dec_conv_%d" % skip_layer) 379 | output = batchnorm(output, name="g_dec_bn_%d" % skip_layer) 380 | 381 | if dropout > 0.0: 382 | output = tf.nn.dropout(output, keep_prob=1 - dropout) 383 | 384 | layers.append(output) 385 | 386 | # decoder_1: [batch, 32, 32, 32, gf_dim * 2] => [batch, 64, 64, 64, 1] 387 | input = tf.concat([layers[-1], layers[0]], axis=-1) 388 | rectified = tf.nn.relu(input) 389 | in_dims = rectified.get_shape().as_list() 390 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, 1] 391 | output = deconv3d(rectified, out_dims, name="g_dec_conv_0") 392 | output = tf.tanh(output) 393 | layers.append(output) 394 | 395 | return layers[-1] 396 | 397 | def sampler(self, generator_inputs): 398 | with tf.variable_scope("pix2pix_generator") as scope: 399 | scope.reuse_variables() 400 | 401 | layers = [] 402 | 403 | # encoder_1: [batch, 64, 64, 64, in_channels] => [batch, 32, 32, gf_dim] 404 | output = conv3d(generator_inputs, self.gf_dim, name="g_enc_conv_0") 405 | layers.append(output) 406 | 407 | layer_specs = [ 408 | self.gf_dim * 2, # encoder_2: [batch, 32, 32, 32, gf_dim] => [batch, 16, 16, 16, gf_dim * 2] 409 | self.gf_dim * 4, # encoder_3: [batch, 16, 16, 16, gf_dim * 2] => [batch, 8, 8, 8, gf_dim * 4] 410 | self.gf_dim * 8, # encoder_4: [batch, 8, 8, 8, gf_dim * 2] => [batch, 4, 4, 4, gf_dim * 4] 411 | self.gf_dim * 8, # encoder_5: [batch, 4, 4, 4, gf_dim * 8] => [batch, 2, 2, 2, gf_dim * 8] 412 | self.gf_dim * 8, # encoder_6: [batch, 2, 2, 2, gf_dim * 8] => [batch, 1, 1, 1, gf_dim * 8] 413 | ] 414 | 415 | for out_channels in layer_specs: 416 | rectified = lrelu(layers[-1], 0.2) 417 | # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels] 418 | convolved = conv3d(rectified, out_channels, name="g_enc_cov_%d" % (len(layers))) 419 | output = batchnorm(convolved, name="g_enc_bn_%d" % (len(layers))) 420 | layers.append(output) 421 | 422 | layer_specs = [ 423 | (self.gf_dim * 8, 0.5), # decoder_6: 
[batch,1,1,1,gf_dim * 8] => [batch,2,2,2,gf_dim * 8 * 2] 424 | (self.gf_dim * 8, 0.5), # decoder_5: [batch,2,2,2,gf_dim * 8 * 2] => [batch,4 4,4,gf_dim * 8 * 2] 425 | (self.gf_dim * 4, 0.5), # decoder_4: [batch,4,4,4,gf_dim * 8 * 2] => [batch,8,8,8,gf_dim * 4 * 2] 426 | (self.gf_dim * 2, 0.0), # decoder_3: [batch,8,8,8,gf_dim * 4 * 2] => [batch,16,16,16,gf_dim * 2 * 2] 427 | (self.gf_dim * 1, 0.0), # decoder_2: [batch,16,16,16,gf_dim * 2 * 2] => [batch,32,32,32,gf_dim * 2] 428 | ] 429 | 430 | num_encoder_layers = len(layers) 431 | for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): 432 | skip_layer = num_encoder_layers - decoder_layer - 1 433 | if decoder_layer == 0: 434 | # first decoder layer doesn't have skip connections 435 | # since it is directly connected to the skip_layer 436 | input = layers[-1] 437 | else: 438 | input = tf.concat([layers[-1], layers[skip_layer]], axis=-1) 439 | 440 | rectified = tf.nn.relu(input) 441 | # [batch, in_depth, in_height, in_width, in_channels] 442 | # => [batch, in_depth, in_height*2, in_width*2, out_channels] 443 | in_dims = rectified.get_shape().as_list() 444 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, out_channels] 445 | output = deconv3d(rectified, out_dims, name="g_dec_conv_%d" % skip_layer) 446 | output = batchnorm(output, name="g_dec_bn_%d" % skip_layer) 447 | 448 | if dropout > 0.0: 449 | output = tf.nn.dropout(output, keep_prob=1 - dropout) 450 | 451 | layers.append(output) 452 | 453 | # decoder_1: [batch, 32, 32, 32, gf_dim * 2] => [batch, 64, 64, 64, 1] 454 | input = tf.concat([layers[-1], layers[0]], axis=-1) 455 | rectified = tf.nn.relu(input) 456 | in_dims = rectified.get_shape().as_list() 457 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, 1] 458 | output = deconv3d(rectified, out_dims, name="g_dec_conv_0") 459 | output = tf.tanh(output) 460 | layers.append(output) 461 | 462 | return layers[-1] 463 | 464 | @property 465 | def model_dir(self): 466 | return "Pix2Pix_" + "{}_{}_{}_{}".format( 467 | self.dataset_name.replace("_freqsplit", ""), 468 | self.image_depth, self.image_height, self.image_width) 469 | 470 | def save(self, checkpoint_dir, step): 471 | model_name = "Pix2Pix.model" 472 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 473 | 474 | if not os.path.exists(checkpoint_dir): 475 | os.makedirs(checkpoint_dir) 476 | 477 | self.saver.save(self.sess, 478 | os.path.join(checkpoint_dir, model_name), 479 | global_step=step) 480 | 481 | def load(self, checkpoint_dir): 482 | import re 483 | print(" [*] Reading checkpoints...") 484 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 485 | 486 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 487 | if ckpt and ckpt.model_checkpoint_path: 488 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 489 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 490 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 491 | print(" [*] Success to read {}".format(ckpt_name)) 492 | return True, counter 493 | else: 494 | print(" [*] Failed to find a checkpoint") 495 | return False, 0 496 | -------------------------------------------------------------------------------- /model_sdfgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import time 3 | from glob import glob 4 | 5 | from ops import * 6 | from utils import * 7 | 8 | 9 | def conv_out_size_same(size, stride): 10 | return 
int(math.ceil(float(size) / float(stride))) 11 | 12 | 13 | class SDFGAN(object): 14 | def __init__(self, sess, num_gpus=1, image_depth=64, image_height=64, image_width=64, save_interval=200, 15 | batch_size=64, sample_num=64, z_dim=200, gf_dim=64, df_dim=64,c_dim=1, dataset_name='shapenet', 16 | input_fname_pattern='*.npy', checkpoint_dir=None, dataset_dir=None, log_dir=None, sample_dir=None): 17 | """ 18 | 19 | Args: 20 | sess: TensorFlow session 21 | batch_size: The size of batch. Should be specified before training. 22 | z_dim: (optional) Dimension of dim for Z. [200] 23 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 24 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64] 25 | c_dim: (optional) Dimension of channels. For sdf input, set to 1. [3] 26 | """ 27 | self.sess = sess 28 | 29 | self.batch_size = batch_size 30 | self.sample_num = sample_num 31 | 32 | self.image_depth = image_depth 33 | self.image_height = image_height 34 | self.image_width = image_width 35 | 36 | self.z_dim = z_dim 37 | 38 | self.gf_dim = gf_dim 39 | self.df_dim = df_dim 40 | 41 | self.c_dim = c_dim 42 | self.save_interval = save_interval 43 | self.num_gpus = num_gpus 44 | self.glob_batch_size = self.num_gpus * self.batch_size 45 | 46 | # batch normalization : deals with poor initialization helps gradient flow 47 | self.d_bn1 = batch_norm(name='d_bn1') 48 | self.d_bn2 = batch_norm(name='d_bn2') 49 | self.d_bn3 = batch_norm(name='d_bn3') 50 | 51 | self.g_bn0 = batch_norm(name='g_bn0') 52 | self.g_bn1 = batch_norm(name='g_bn1') 53 | self.g_bn2 = batch_norm(name='g_bn2') 54 | self.g_bn3 = batch_norm(name='g_bn3') 55 | 56 | self.dataset_name = dataset_name 57 | self.input_fname_pattern = input_fname_pattern 58 | self.checkpoint_dir = checkpoint_dir 59 | self.dataset_dir = dataset_dir 60 | self.log_dir = os.path.join(log_dir, "sdfgan_log") 61 | self.sample_dir = os.path.join(sample_dir, "sdfgan_sample") 62 | self.build_model() 63 | 64 | def build_model(self): 65 | 66 | image_dims = [self.image_depth, self.image_height, self.image_width, self.c_dim] 67 | 68 | # input placeholders 69 | self.inputs = tf.placeholder( 70 | tf.float32, [self.glob_batch_size] + image_dims, name='real_images') 71 | self.z = tf.placeholder( 72 | tf.float32, [None, self.z_dim], name='z') # latent vector 73 | self.z_sum = histogram_summary("z", self.z) 74 | 75 | self.n_eff = tf.placeholder(tf.int32, name='n_eff') # overall number of effective data points 76 | 77 | # initialize global lists 78 | self.G = [None] * self.num_gpus # generator results 79 | self.D = [None] * self.num_gpus # discriminator results for real images 80 | self.D_logits = [None] * self.num_gpus 81 | self.D_ = [None] * self.num_gpus # discriminator results for fake images 82 | self.D_logits_ = [None] * self.num_gpus 83 | self.d_loss_real = [None] * self.num_gpus 84 | self.d_loss_fake = [None] * self.num_gpus 85 | self.d_losses = [None] * self.num_gpus 86 | self.g_losses = [None] * self.num_gpus 87 | self.d_accus = [None] * self.num_gpus # discriminator accuracy 88 | self.n_effs = [None] * self.num_gpus 89 | 90 | # compute using multiple gpus 91 | with tf.variable_scope(tf.get_variable_scope()) as vscope: 92 | for gpuid in xrange(self.num_gpus): 93 | with tf.device('/gpu:%d' % gpuid): 94 | 95 | # range of data for this gpu 96 | gpu_start = gpuid * self.batch_size 97 | gpu_end = (gpuid + 1) * self.batch_size 98 | 99 | # number of effective data points 100 | gpu_n_eff = tf.reduce_min([tf.reduce_max([0, self.n_eff - gpu_start]), 
self.batch_size]) 101 | 102 | # create examples and pass through discriminator 103 | gpu_G = self.generator(self.z[gpu_start:gpu_end]) 104 | gpu_D, gpu_D_logits = self.discriminator(self.inputs[gpu_start:gpu_end]) 105 | gpu_D_, gpu_D_logits_ = self.discriminator(gpu_G, reuse=True) 106 | 107 | # compatibility across different tf versions 108 | def sigmoid_cross_entropy_with_logits(x, y): 109 | try: 110 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y) 111 | except: 112 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y) 113 | 114 | # compute loss and accuracy 115 | gpu_d_loss_real = tf.reduce_mean( 116 | sigmoid_cross_entropy_with_logits(gpu_D_logits[:gpu_n_eff], tf.ones_like(gpu_D[:gpu_n_eff]))) 117 | gpu_d_loss_fake = tf.reduce_mean( 118 | sigmoid_cross_entropy_with_logits(gpu_D_logits_[:gpu_n_eff], tf.zeros_like(gpu_D_[:gpu_n_eff]))) 119 | gpu_g_loss_gen = tf.reduce_mean( 120 | sigmoid_cross_entropy_with_logits(gpu_D_logits_[:gpu_n_eff], tf.ones_like(gpu_D_[:gpu_n_eff]))) 121 | gpu_d_loss = gpu_d_loss_real + gpu_d_loss_fake 122 | gpu_d_accu_real = tf.reduce_sum(tf.cast(gpu_D[:gpu_n_eff] > .5, tf.int32)) / gpu_D.get_shape()[0] 123 | gpu_d_accu_fake = tf.reduce_sum(tf.cast(gpu_D_[:gpu_n_eff] < .5, tf.int32)) / gpu_D_.get_shape()[0] 124 | gpu_d_accu = (gpu_d_accu_real + gpu_d_accu_fake) / 2 125 | 126 | # combined generator loss 127 | gpu_g_loss = gpu_g_loss_gen 128 | 129 | # add gpu-wise data to global list 130 | self.G[gpuid] = gpu_G 131 | self.D[gpuid] = gpu_D 132 | self.D_[gpuid] = gpu_D_ 133 | self.D_logits[gpuid] = gpu_D_logits 134 | self.D_logits_[gpuid] = gpu_D_logits_ 135 | self.d_loss_real[gpuid] = gpu_d_loss_real 136 | self.d_loss_fake[gpuid] = gpu_d_loss_fake 137 | self.d_losses[gpuid] = gpu_d_loss 138 | self.g_losses[gpuid] = gpu_g_loss 139 | self.d_accus[gpuid] = gpu_d_accu 140 | self.n_effs[gpuid] = gpu_n_eff 141 | 142 | # Reuse variables for the next gpu 143 | tf.get_variable_scope().reuse_variables() 144 | 145 | # concatenate across GPUs 146 | self.D = tf.concat(self.D, axis=0) 147 | self.D_ = tf.concat(self.D_, axis=0) 148 | self.G = tf.concat(self.G, axis=0) 149 | weighted_d_loss_real = [self.d_loss_real[j] * tf.cast(self.n_effs[j], tf.float32) 150 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)] 151 | weighted_d_loss_fake = [self.d_loss_fake[j] * tf.cast(self.n_effs[j], tf.float32) 152 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)] 153 | weighted_d_loss = [self.d_losses[j] * tf.cast(self.n_effs[j], tf.float32) 154 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)] 155 | weighted_g_loss = [self.g_losses[j] * tf.cast(self.n_effs[j], tf.float32) 156 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)] 157 | weighted_d_accu = [tf.cast(self.d_accus[j], tf.float32) * tf.cast(self.n_effs[j], tf.float32) 158 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)] 159 | 160 | self.d_loss_real = tf.reduce_sum(weighted_d_loss_real, axis=0) 161 | self.d_loss_fake = tf.reduce_sum(weighted_d_loss_fake, axis=0) 162 | self.d_loss = tf.reduce_sum(weighted_d_loss, axis=0) 163 | self.g_loss = tf.reduce_sum(weighted_g_loss, axis=0) 164 | self.d_accu = tf.reduce_sum(weighted_d_accu, axis=0) 165 | 166 | # summarize variables 167 | self.d_sum = histogram_summary("d", self.D) 168 | self.d__sum = histogram_summary("d_", self.D_) 169 | self.g_sum = image_summary("G", self.G[:, int(self.image_depth / 2), :, :]) 170 | 171 | self.d_loss_real_sum = scalar_summary("d_loss_real", 
self.d_loss_real) 172 | self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake) 173 | self.d_loss_sum = scalar_summary("d_loss", self.d_loss) 174 | 175 | self.g_loss_sum = scalar_summary("g_loss", self.g_loss) 176 | self.d_accu_sum = scalar_summary("d_accu", self.d_accu) 177 | 178 | # define trainable variables for generator and discriminator 179 | t_vars = tf.trainable_variables() 180 | self.d_vars = [var for var in t_vars if 'd_' in var.name] 181 | self.g_vars = [var for var in t_vars if 'g_' in var.name] 182 | 183 | self.sampler = self.generator(self.z, reuse=True, is_train=False) 184 | self.saver = tf.train.Saver() 185 | 186 | def train(self, config): 187 | """Train SFDGAN""" 188 | data = glob(os.path.join(self.dataset_dir, config.dataset, self.input_fname_pattern)) 189 | np.random.shuffle(data) 190 | 191 | # define optimization operation 192 | d_opt = tf.train.AdamOptimizer(config.d_learning_rate, beta1=config.beta1) 193 | g_opt = tf.train.AdamOptimizer(config.g_learning_rate, beta1=config.beta1) 194 | 195 | # create list of grads from different gpus 196 | global_d_grads_vars = [None] * self.num_gpus 197 | global_g_grads_vars = [None] * self.num_gpus 198 | 199 | # compute d gradients 200 | with tf.variable_scope(tf.get_variable_scope()): 201 | for gpuid in xrange(self.num_gpus): 202 | with tf.device('/gpu:%d' % gpuid): 203 | gpu_d_grads_vars = d_opt.compute_gradients(loss=self.d_losses[gpuid], var_list=self.d_vars) 204 | global_d_grads_vars[gpuid] = gpu_d_grads_vars 205 | 206 | # compute g gradients 207 | with tf.variable_scope(tf.get_variable_scope()): 208 | for gpuid in xrange(self.num_gpus): 209 | with tf.device('/gpu:%d' % gpuid): 210 | gpu_g_grads_vars = g_opt.compute_gradients(loss=self.g_losses[gpuid], var_list=self.g_vars) 211 | global_g_grads_vars[gpuid] = gpu_g_grads_vars 212 | 213 | # average gradients across gpus and apply gradients 214 | d_grads_vars = average_gradients(global_d_grads_vars) 215 | g_grads_vars = average_gradients(global_g_grads_vars) 216 | d_optim = d_opt.apply_gradients(d_grads_vars) 217 | g_optim = g_opt.apply_gradients(g_grads_vars) 218 | 219 | # compatibility across tf versions 220 | try: 221 | tf.global_variables_initializer().run() 222 | except: 223 | tf.initialize_all_variables().run() 224 | 225 | self.g_sum = merge_summary([self.z_sum, self.d__sum, self.g_sum, self.d_loss_fake_sum, self.g_loss_sum]) 226 | self.d_sum = merge_summary([self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]) 227 | 228 | self.writer = SummaryWriter(self.log_dir, self.sess.graph) 229 | sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim)) 230 | 231 | counter = 0 232 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 233 | if could_load: 234 | counter = checkpoint_counter 235 | print(" [*] Load SUCCESS") 236 | else: 237 | print(" [!] 
Load failed...") 238 | 239 | d_accu_last_batch = .5 240 | batch_idxs = int(math.ceil(min(len(data), config.train_size) / self.glob_batch_size)) - 1 241 | total_steps = config.epoch * batch_idxs 242 | prev_time = -np.inf 243 | 244 | for epoch in xrange(config.epoch): 245 | # shuffle data before training in each epoch 246 | np.random.shuffle(data) 247 | for idx in xrange(0, batch_idxs - 1): 248 | glob_batch_files = data[idx * self.glob_batch_size:(idx + 1) * self.glob_batch_size] 249 | glob_batch = [ 250 | np.load(batch_file)[0, :, :, :] for batch_file in glob_batch_files] 251 | glob_batch_images = np.array(glob_batch).astype(np.float32)[:, :, :, :, None] 252 | 253 | glob_batch_z = np.random.uniform(-1, 1, [self.glob_batch_size, self.z_dim]) \ 254 | .astype(np.float32) 255 | n_eff = len(glob_batch_files) 256 | 257 | # Pad zeros if effective batch size is smaller than global batch size 258 | if n_eff != self.glob_batch_size: 259 | glob_batch_images = pad_glob_batch(glob_batch_images, self.glob_batch_size) 260 | 261 | # Update D network if accuracy in last batch <= 80% 262 | if d_accu_last_batch < .8: 263 | # Update D network 264 | _, summary_str = self.sess.run([d_optim, self.d_sum], 265 | feed_dict={self.inputs: glob_batch_images, 266 | self.z: glob_batch_z, 267 | self.n_eff: n_eff}) 268 | self.writer.add_summary(summary_str, counter) 269 | 270 | # Update G network 271 | _, summary_str = self.sess.run([g_optim, self.g_sum], 272 | feed_dict={self.z: glob_batch_z, 273 | self.n_eff: n_eff}) 274 | self.writer.add_summary(summary_str, counter) 275 | 276 | # Compute last batch accuracy and losses 277 | d_accu_last_batch, errD_fake, errD_real, errG \ 278 | = self.sess.run([self.d_accu, self.d_loss_fake, self.d_loss_real, self.g_loss], 279 | feed_dict={self.inputs: glob_batch_images, 280 | self.z: glob_batch_z, 281 | self.n_eff: n_eff}) 282 | self.writer.add_summary(summary_str, counter) 283 | 284 | # get time 285 | now_time = time.time() 286 | time_per_iter = now_time - prev_time 287 | prev_time = now_time 288 | eta = (total_steps - counter + checkpoint_counter) * time_per_iter 289 | counter += 1 290 | 291 | try: 292 | timestr = time.strftime("%H:%M:%S", time.gmtime(eta)) 293 | except: 294 | timestr = '?:?:?' 295 | 296 | print("Epoch:[%3d] [%3d/%3d] Iter:[%5d] eta(h:m:s): %s, d_loss: %.8f, g_loss: %.8f, d_accu: %.4f" 297 | % (epoch, idx, batch_idxs, 298 | counter, timestr, errD_fake + errD_real, errG, d_accu_last_batch)) 299 | 300 | # save model and samples every save_interval iterations 301 | if np.mod(counter, self.save_interval) == 1: 302 | samples = self.sess.run(self.sampler,feed_dict={self.z: sample_z}) 303 | np.save(self.sample_dir+'/sample_{:05d}.npy'.format(counter), samples) 304 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint." 305 | .format(counter, self.sample_num)) 306 | self.save(config.checkpoint_dir, counter) 307 | 308 | # save last checkpoint 309 | samples = self.sess.run(self.sampler,feed_dict={self.z: sample_z}) 310 | np.save(self.sample_dir+'/sample_{:05d}.npy'.format(counter), samples) 311 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint.".format(counter, self.sample_num)) 312 | self.save(config.checkpoint_dir, counter) 313 | 314 | print("{!] 
Training of SDFGAN Complete.") 315 | 316 | def discriminator(self, image, reuse=False): 317 | with tf.variable_scope("sdfgan_discriminator") as scope: 318 | if reuse: 319 | scope.reuse_variables() 320 | 321 | h0 = lrelu(conv3d(image, self.df_dim, name='d_h0_conv')) 322 | h1 = lrelu(self.d_bn1(conv3d(h0, self.df_dim * 2, name='d_h1_conv'))) 323 | h2 = lrelu(self.d_bn2(conv3d(h1, self.df_dim * 4, name='d_h2_conv'))) 324 | h3 = lrelu(self.d_bn3(conv3d(h2, self.df_dim * 8, name='d_h3_conv'))) 325 | h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, name='d_h3_lin') 326 | 327 | return tf.nn.sigmoid(h4), h4 328 | 329 | def generator(self, z, reuse=False, is_train=True): 330 | with tf.variable_scope("sdfgan_generator") as scope: 331 | if reuse: 332 | scope.reuse_variables() 333 | 334 | s_d, s_h, s_w = self.image_depth, self.image_height, self.image_width 335 | s_d2, s_h2, s_w2 = conv_out_size_same(s_d, 2), conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2) 336 | s_d4, s_h4, s_w4 = conv_out_size_same(s_d2, 2), conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2) 337 | s_d8, s_h8, s_w8 = conv_out_size_same(s_d4, 2), conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2) 338 | s_d16, s_h16, s_w16 = conv_out_size_same(s_d8, 2), conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2) 339 | 340 | # project `z` and reshape 341 | self.z_, self.h0_w, self.h0_b = linear( 342 | z, self.gf_dim * 8 * s_d16 * s_h16 * s_w16, 'g_h0_lin', with_w=True) 343 | 344 | self.h0 = tf.reshape( 345 | self.z_, [-1, s_d16, s_h16, s_w16, self.gf_dim * 8]) 346 | h0 = tf.nn.relu(self.g_bn0(self.h0, train=is_train)) 347 | 348 | self.h1, self.h1_w, self.h1_b = deconv3d( 349 | h0, [self.batch_size, s_d8, s_h8, s_w8, self.gf_dim * 4], name='g_h1', with_w=True) 350 | h1 = tf.nn.relu(self.g_bn1(self.h1, train=is_train)) 351 | 352 | h2, self.h2_w, self.h2_b = deconv3d( 353 | h1, [self.batch_size, s_d4, s_h4, s_w4, self.gf_dim * 2], name='g_h2', with_w=True) 354 | h2 = tf.nn.relu(self.g_bn2(h2, train=is_train)) 355 | 356 | h3, self.h3_w, self.h3_b = deconv3d( 357 | h2, [self.batch_size, s_d2, s_h2, s_w2, self.gf_dim * 1], name='g_h3', with_w=True) 358 | h3 = tf.nn.relu(self.g_bn3(h3, train=is_train)) 359 | 360 | h4, self.h4_w, self.h4_b = deconv3d( 361 | h3, [self.batch_size, s_d, s_h, s_w, self.c_dim], name='g_h4', with_w=True) 362 | 363 | return tf.nn.tanh(h4) 364 | 365 | @property 366 | def model_dir(self): 367 | return "SDFGAN_" + "{}_{}_{}_{}".format( 368 | self.dataset_name, 369 | self.image_depth, self.image_height, self.image_width) 370 | 371 | def save(self, checkpoint_dir, step): 372 | model_name = "SDFGAN.model" 373 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 374 | 375 | if not os.path.exists(checkpoint_dir): 376 | os.makedirs(checkpoint_dir) 377 | 378 | self.saver.save(self.sess, 379 | os.path.join(checkpoint_dir, model_name), 380 | global_step=step) 381 | 382 | def load(self, checkpoint_dir): 383 | import re 384 | print(" [*] Reading checkpoints...") 385 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 386 | 387 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 388 | if ckpt and ckpt.model_checkpoint_path: 389 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 390 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 391 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 392 | print(" [*] Success to read {}".format(ckpt_name)) 393 | return True, counter 394 | else: 395 | print(" [*] Failed to find a checkpoint") 396 | 
return False, 0 397 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | try: 4 | image_summary = tf.summary.image 5 | scalar_summary = tf.summary.scalar 6 | histogram_summary = tf.summary.histogram 7 | merge_summary = tf.summary.merge 8 | SummaryWriter = tf.summary.FileWriter 9 | except: 10 | image_summary = tf.image_summary 11 | scalar_summary = tf.scalar_summary 12 | histogram_summary = tf.histogram_summary 13 | merge_summary = tf.merge_summary 14 | SummaryWriter = tf.train.SummaryWriter 15 | 16 | if "concat_v2" in dir(tf): 17 | def concat(tensors, axis, *args, **kwargs): 18 | return tf.concat_v2(tensors, axis, *args, **kwargs) 19 | else: 20 | def concat(tensors, axis, *args, **kwargs): 21 | return tf.concat(tensors, axis, *args, **kwargs) 22 | 23 | 24 | class batch_norm(object): 25 | def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"): 26 | with tf.variable_scope(name): 27 | self.epsilon = epsilon 28 | self.momentum = momentum 29 | self.name = name 30 | 31 | def __call__(self, x, train=True): 32 | return tf.contrib.layers.batch_norm(x, 33 | decay=self.momentum, 34 | updates_collections=None, 35 | epsilon=self.epsilon, 36 | scale=True, 37 | is_training=train, 38 | scope=self.name) 39 | 40 | 41 | def batchnorm(input, name="batchnorm"): 42 | with tf.variable_scope(name): 43 | # this block looks like it has 3 inputs on the graph unless we do this 44 | input = tf.identity(input) 45 | 46 | channels = input.get_shape()[-1] 47 | offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) 48 | scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02)) 49 | mean, variance = tf.nn.moments(input, axes=[0, 1, 2, 3], keep_dims=False) 50 | variance_epsilon = 1e-5 51 | normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) 52 | return normalized 53 | 54 | 55 | def conv3d(input_, output_dim, 56 | k_d=4, k_h=4, k_w=4, d_d=2, d_h=2, d_w=2, stddev=0.02, 57 | name="conv3d"): 58 | with tf.variable_scope(name): 59 | w = tf.get_variable('w', [k_d, k_h, k_w, input_.get_shape()[-1], output_dim], 60 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 61 | conv = tf.nn.conv3d(input_, w, strides=[1, d_d, d_h, d_w, 1], padding='SAME') 62 | 63 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 64 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 65 | 66 | return conv 67 | 68 | 69 | def deconv3d(input_, output_shape, 70 | k_d=4, k_h=4, k_w=4, d_d=2, d_h=2, d_w=2, stddev=0.02, 71 | name="deconv3d", with_w=False): 72 | with tf.variable_scope(name): 73 | # filter : [height, width, output_channels, in_channels] 74 | w = tf.get_variable('w', [k_d, k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 75 | initializer=tf.random_normal_initializer(stddev=stddev)) 76 | 77 | try: 78 | deconv = tf.nn.conv3d_transpose(input_, w, output_shape=output_shape, 79 | strides=[1, d_d, d_h, d_w, 1]) 80 | 81 | # Support for verisons of TensorFlow before 0.7.0 82 | except AttributeError: 83 | deconv = tf.nn.deconv3d(input_, w, output_shape=output_shape, 84 | strides=[1, d_d, d_h, d_w, 1]) 85 | 86 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 87 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), 
deconv.get_shape()) 88 | 89 | if with_w: 90 | return deconv, w, biases 91 | else: 92 | return deconv 93 | 94 | 95 | def lrelu(x, leak=0.2, name="lrelu"): 96 | return tf.maximum(x, leak * x) 97 | 98 | 99 | def linear(input_, output_size, name='linear', stddev=0.02, bias_start=0.0, with_w=False): 100 | shape = input_.get_shape().as_list() 101 | 102 | with tf.variable_scope(name): 103 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 104 | tf.random_normal_initializer(stddev=stddev)) 105 | bias = tf.get_variable("bias", [output_size], 106 | initializer=tf.constant_initializer(bias_start)) 107 | if with_w: 108 | return tf.matmul(input_, matrix) + bias, matrix, bias 109 | else: 110 | return tf.matmul(input_, matrix) + bias 111 | 112 | 113 | def freq_split(s, r=8, mask_type='boxed'): 114 | s = np.fft.fftn(s) 115 | dims = s.shape 116 | mids = [int(dims[i] / 2) for i in range(len(dims))] 117 | 118 | # frequency mask 119 | if mask_type == 'circular': 120 | U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]] 121 | D = np.sqrt(np.square(U) + np.square(V) + np.square(W)) 122 | lf_mask = np.fft.ifftshift(np.less_equal(D, r * np.ones_like(D)).astype(np.complex_)) 123 | hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask 124 | 125 | elif mask_type == 'boxed': 126 | U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]] 127 | mask = np.array([U < r, U >= -r, V < r, V >= -r, W < r, W >= -r]) 128 | lf_mask = np.fft.ifftshift(np.all(mask, axis=0).astype(np.complex_)) 129 | hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask 130 | 131 | else: 132 | raise Exception('mask_type undefined.') 133 | 134 | # apply mask 135 | s_lf = np.multiply(s, lf_mask) 136 | s_hf = np.multiply(s, hf_mask) 137 | 138 | s_lf = np.real(np.fft.ifftn(s_lf)) 139 | s_hf = np.real(np.fft.ifftn(s_hf)) 140 | 141 | return s_lf, s_hf 142 | 143 | 144 | def batch_lowpass(s_batch, r=8, mask_type='boxed'): 145 | assert len(s_batch.shape) > 3 146 | process_list = [] 147 | for i in range(s_batch.shape[0]): 148 | s_lf, _ = freq_split(s_batch[i, :, :, :], r=r, mask_type=mask_type) 149 | process_list.append(s_lf) 150 | 151 | return np.array(process_list) 152 | 153 | 154 | def batch_mirr(s_batch, mirr='l'): 155 | assert len(s_batch.shape) > 3 156 | process_list = [] 157 | for i in range(s_batch.shape[0]): 158 | s = s_batch[i] 159 | if mirr == 'r': 160 | sr = s[:, :, int(s.shape[2] / 2):] 161 | sl = np.flip(sr, axis=2) 162 | s = np.concatenate((sl, sr), axis=2) 163 | elif mirr == 'l': 164 | sl = s[:, :, :int(s.shape[2] / 2)] 165 | sr = np.flip(sl, axis=2) 166 | s = np.concatenate((sl, sr), axis=2) 167 | process_list.append(s) 168 | 169 | return np.array(process_list) 170 | 171 | 172 | 173 | def pad_glob_batch(glob_batch_images, glob_batch_size): 174 | """pad tensor with zeros if smaller than a full batch""" 175 | pad_shape = [j for j in glob_batch_images.shape] 176 | pad_shape[0] = glob_batch_size - pad_shape[0] 177 | pad_zeros = np.zeros(pad_shape) 178 | glob_batch_inputs = np.concatenate((glob_batch_images, pad_zeros), axis=0) 179 | 180 | return glob_batch_inputs 181 | 182 | 183 | def average_gradients(tower_grads): 184 | """Calculate the average gradient for each shared variable across all towers. 185 | 186 | Note that this function provides a synchronization point across all towers. 187 | 188 | Args: 189 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 190 | is over individual gradients. 
def batch_lowpass(s_batch, r=8, mask_type='boxed'):
    """Keep only the low-frequency component of each field in the batch."""
    assert len(s_batch.shape) > 3
    process_list = []
    for i in range(s_batch.shape[0]):
        s_lf, _ = freq_split(s_batch[i, :, :, :], r=r, mask_type=mask_type)
        process_list.append(s_lf)

    return np.array(process_list)


def batch_mirr(s_batch, mirr='l'):
    """Mirror each field about its mid-plane along axis 2, keeping the
    left ('l') or right ('r') half."""
    assert len(s_batch.shape) > 3
    process_list = []
    for i in range(s_batch.shape[0]):
        s = s_batch[i]
        if mirr == 'r':
            sr = s[:, :, int(s.shape[2] / 2):]
            sl = np.flip(sr, axis=2)
            s = np.concatenate((sl, sr), axis=2)
        elif mirr == 'l':
            sl = s[:, :, :int(s.shape[2] / 2)]
            sr = np.flip(sl, axis=2)
            s = np.concatenate((sl, sr), axis=2)
        process_list.append(s)

    return np.array(process_list)


def pad_glob_batch(glob_batch_images, glob_batch_size):
    """Pad a tensor with zeros if it is smaller than a full batch."""
    pad_shape = [j for j in glob_batch_images.shape]
    pad_shape[0] = glob_batch_size - pad_shape[0]
    pad_zeros = np.zeros(pad_shape)
    glob_batch_inputs = np.concatenate((glob_batch_images, pad_zeros), axis=0)

    return glob_batch_inputs


def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer list
            is over individual gradients. The inner list is over the gradient
            calculation for each tower.
    Returns:
        List of pairs of (gradient, variable) where the gradient has been averaged
        across all towers.
    """
    # Borrowed code from
    # https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py

    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            # Add a 0th dimension to the gradients to represent the tower.
            expanded_g = tf.expand_dims(g, 0)

            # Append on a 'tower' dimension which we will average over below.
            grads.append(expanded_g)

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # The Variables are redundant because they are shared across towers,
        # so just return the first tower's pointer to the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
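
# Illustrative sketch of typical multi-GPU usage: collect per-tower
# (gradient, variable) lists from compute_gradients, then average before
# applying. Names below are hypothetical (build_tower_loss is an assumed
# model-specific helper, opt a tf.train optimizer):
#
#   tower_grads = []
#   for gpu in range(num_gpus):
#       with tf.device('/gpu:%d' % gpu):
#           loss = build_tower_loss(gpu)
#           tower_grads.append(opt.compute_gradients(loss))
#   train_op = opt.apply_gradients(average_gradients(tower_grads))
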
-------------------------------------------------------------------------------- /stream_freqsplit.py: --------------------------------------------------------------------------------
from __future__ import print_function
from glob import glob
import multiprocessing
from os import path, makedirs
from tqdm import tqdm
import os
import numpy as np
import argparse


def freq_split(s, r=8, mask_type='boxed'):
    '''
    Split frequencies of a given 3D signal
    :param s: input signal matrix of shape (dim_0, dim_1, dim_2)
    :param r: number of modes to split
    :param mask_type: 'boxed' or 'circular'
    :return: s_lf, s_hf: low and high frequencies of the signal
    '''

    s = np.fft.fftn(s)
    dims = s.shape
    mids = [int(dims[i] / 2) for i in range(len(dims))]

    # frequency mask
    if mask_type == 'circular':
        U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]]
        D = np.sqrt(np.square(U) + np.square(V) + np.square(W))
        lf_mask = np.fft.ifftshift(np.less_equal(D, r * np.ones_like(D)).astype(np.complex_))
        hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask

    elif mask_type == 'boxed':
        U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]]
        mask = np.array([U < r, U >= -r, V < r, V >= -r, W < r, W >= -r])
        lf_mask = np.fft.ifftshift(np.all(mask, axis=0).astype(np.complex_))
        hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask

    else:
        raise ValueError('mask_type undefined; choose "boxed" or "circular".')

    # apply mask
    s_lf = np.multiply(s, lf_mask)
    s_hf = np.multiply(s, hf_mask)

    s_lf = np.real(np.fft.ifftn(s_lf))
    s_hf = np.real(np.fft.ifftn(s_hf))

    return s_lf, s_hf


def process_mesh(read_path, write_path):
    # read the SDF field and save its low/high-frequency split
    s = np.load(read_path)[0, :, :, :]
    s_lf, s_hf = freq_split(s, r=8)
    s_out = np.array([s_lf, s_hf])
    np.save(write_path, s_out)


def helper(args):
    # unpack the (read_path, write_path) tuple for Pool.imap_unordered
    return process_mesh(*args)


def base_name(pathname):
    """Return the base name of a path, i.e. the final folder or file name."""
    return os.path.basename(os.path.normpath(pathname))


def cond_mkdir(directory):
    """Conditionally make the directory if it does not exist."""
    if not path.exists(directory):
        makedirs(directory)
    else:
        print('Directory ' + directory + ' already exists. Did not create.')


if __name__ == '__main__':

    # argument parsing
    parser = argparse.ArgumentParser(
        description="split SDF fields into low/high-frequency pairs using the python multiprocessing module")
    parser.add_argument('datadir', type=str, help='Path to original dataset.')
    parser.add_argument('savedir', type=str, help='Path to save the processed dataset.')
    parser.add_argument('ncore', type=int, default=0, help='Number of cores. Enter 0 to use all available cpu cores.')
    parser.add_argument('nmesh', type=int, default=0, help='Process only nmesh meshes. Enter 0 to process all meshes.')
    argument = parser.parse_args()

    data_dir = argument.datadir
    save_dir = argument.savedir

    if argument.ncore == 0:
        num_cores = multiprocessing.cpu_count()
        print("Using max cpu cores = %d" % num_cores)
    else:
        num_cores = argument.ncore

    cond_mkdir(save_dir)
    read_loc = glob(os.path.join(data_dir, "*"))
    write_loc = []

    for d_in in read_loc:
        d_out = path.join(save_dir, base_name(d_in))
        write_loc.append(d_out)

    if argument.nmesh != 0:
        read_loc = read_loc[:argument.nmesh]
        write_loc = write_loc[:argument.nmesh]

    # build (read_path, write_path) argument tuples for the worker pool
    args = list(zip(read_loc, write_loc))

    print("Starting multiprocess processing with {0} cores.".format(num_cores))

    # pass num_cores so the requested core count is actually honored
    pool = multiprocessing.Pool(num_cores)

    # a single tqdm progress bar, updated once per completed field
    with tqdm(total=len(read_loc)) as pbar:
        for _ in pool.imap_unordered(helper, args):
            pbar.update()
    pool.close()
    pool.join()
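
# Illustrative sketch: to process a single field without the pool, e.g. for
# debugging (file names below are hypothetical):
#
#   process_mesh('data/synset_02958343_car/model_0001.npy',
#                'data/synset_02958343_car_freqsplit/model_0001.npy')
#
# For the 64^3 fields used here, the saved array has shape (2, 64, 64, 64):
# index 0 holds the low-frequency component, index 1 the high-frequency residual.
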
-------------------------------------------------------------------------------- /test_all.sh: --------------------------------------------------------------------------------
#!/bin/bash

# import global variables PROJ_DIR, JOB_NAME and DATASET
. ./define_path.sh

# variables
JOB_DIR=${PROJ_DIR}/jobs/${JOB_NAME}
DATASET_DIR=${PROJ_DIR}/data
LOG_DIR=${JOB_DIR}/logs
SAMPLE_DIR=${JOB_DIR}/samples
CHECKPOINT_DIR=${JOB_DIR}/checkpoint

# execute
python main.py \
  --batch_size=64 \
  --dataset_dir=${DATASET_DIR} \
  --log_dir=${LOG_DIR} \
  --sample_dir=${SAMPLE_DIR} \
  --checkpoint_dir=${CHECKPOINT_DIR} \
  --dataset=${DATASET} \
  --sample_num=64
-------------------------------------------------------------------------------- /train_pix2pix.sh: --------------------------------------------------------------------------------
#!/bin/bash

# import global variables PROJ_DIR, JOB_NAME and DATASET
. ./define_path.sh

# variables
JOB_DIR=${PROJ_DIR}/jobs/${JOB_NAME}
DATASET_DIR=${PROJ_DIR}/data
LOG_DIR=${JOB_DIR}/logs
SAMPLE_DIR=${JOB_DIR}/samples
CHECKPOINT_DIR=${JOB_DIR}/checkpoint

# execute
python main.py \
  --model=pix2pix \
  --is_train \
  --epoch=500 \
  --batch_size=32 \
  --dataset_dir=${DATASET_DIR} \
  --log_dir=${LOG_DIR} \
  --sample_dir=${SAMPLE_DIR} \
  --checkpoint_dir=${CHECKPOINT_DIR} \
  --num_gpus=4 \
  --dataset=${DATASET}_freqsplit \
  --g_learning_rate=0.0005 \
  --d_learning_rate=0.0002 \
  --gan_weight=1 \
  --l1_weight=100 \
  --image_depth=64 \
  --beta1=0.5
-------------------------------------------------------------------------------- /train_sdfgan.sh: --------------------------------------------------------------------------------
#!/bin/bash

# import global variables PROJ_DIR, JOB_NAME and DATASET
. ./define_path.sh

# variables
JOB_DIR=${PROJ_DIR}/jobs/${JOB_NAME}
DATASET_DIR=${PROJ_DIR}/data
LOG_DIR=${JOB_DIR}/logs
SAMPLE_DIR=${JOB_DIR}/samples
CHECKPOINT_DIR=${JOB_DIR}/checkpoint

# execute
python main.py \
  --model=sdfgan \
  --is_train \
  --epoch=500 \
  --batch_size=64 \
  --dataset_dir=${DATASET_DIR} \
  --log_dir=${LOG_DIR} \
  --sample_dir=${SAMPLE_DIR} \
  --checkpoint_dir=${CHECKPOINT_DIR} \
  --num_gpus=4 \
  --dataset=${DATASET} \
  --g_learning_rate=0.0005 \
  --d_learning_rate=0.0002 \
  --image_depth=64 \
  --beta1=0.5 \
  --is_new
-------------------------------------------------------------------------------- /utils.py: --------------------------------------------------------------------------------
"""
Some code from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import math
import os
import pprint
import numpy as np

import tensorflow as tf
import tensorflow.contrib.slim as slim

pp = pprint.PrettyPrinter()

get_stddev = lambda x, k_h, k_w: 1 / math.sqrt(k_w * k_h * x.get_shape()[-1])


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def create_sdfgan_samples(sess, sdfgan, config):
    """Draw latent samples, run the sdfgan sampler, and save the results to disk."""
    z_sample = np.random.uniform(-1, 1, size=(config.sample_num, sdfgan.z_dim))
    samples = sess.run(sdfgan.sampler, feed_dict={sdfgan.z: z_sample})
    fname = os.path.join(config.sample_dir, "sdfgan_final_samples.npy")
    np.save(fname, samples)

    return fname


def create_pix2pix_samples(sess, pix2pix, sample_in):
    """Run the pix2pix sampler on a batch of input fields, adding a channel axis if needed."""
    if len(sample_in.shape) == 4:
        sample_in = np.expand_dims(sample_in, axis=-1)
    sample_out = sess.run(pix2pix.sampler, feed_dict={pix2pix.sample_inputs: sample_in})

    return sample_out
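
# Illustrative sketch of how the two helpers above chain at test time
# (session and model objects are hypothetical; since the pix2pix stage is
# trained on frequency-split data, it is assumed here to refine the coarse
# sdfgan samples):
#
#   fname = create_sdfgan_samples(sess, sdfgan, config)  # saves samples to .npy
#   coarse = np.load(fname)
#   refined = create_pix2pix_samples(sess, pix2pix, coarse)
--------------------------------------------------------------------------------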