├── README.md
├── define_path.sh
├── images
│   ├── car_15_small.gif
│   ├── car_sample.stl
│   ├── furniture-render.gif
│   └── morph.gif
├── main.py
├── model_pix2pix.py
├── model_sdfgan.py
├── ops.py
├── stream_freqsplit.py
├── test_all.sh
├── train_pix2pix.sh
├── train_sdfgan.sh
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # 3D-SDFGAN
2 | 3D Signed Distance Function Based Generative Adversarial Networks
3 |
12 | ## About this study
13 | This study seeks to generate realistic-looking, mesh-based 3D models using GANs. Training is based on the [ShapeNetCore dataset](https://www.shapenet.org/), which has been post-processed into 64x64x64 signed distance function fields. More details about this study can be found in [this paper](https://arxiv.org/abs/1709.07581).
14 |
15 | ## Reproducing the results
16 | ### Collecting the dataset
17 | First, define the project directory that will store the data and computed results (checkpoints, logs, samples). Set the project directory, job name, and dataset to train on in the file `define_path.sh`. These parameters are used globally throughout training and testing:
18 | ```bash
19 | # variables to define by user
20 | PROJ_DIR=/Path/To/Your/Project/Directory
21 | JOB_NAME=job0-yourjobname
22 | DATASET=synset_02958343_car
23 | ```
24 | Source the script to export the variables into the current shell (running it as `./define_path.sh` would only set them in a subshell):
25 | ```
26 | sudo chmod +x *.sh
27 | source ./define_path.sh
28 | ```
29 | Go to the project directory and download the data (~28.8 GB; this might take a while):
30 | ```
31 | cd $PROJ_DIR
32 | wget http://island.me.berkeley.edu/data/data.tar.gz
33 | tar -xzvf data.tar.gz
34 | ```
35 | Go back to the code directory and run the frequency-splitting code on a dataset (e.g. the car dataset) to prepare inputs for pix2pix:
36 | ```
37 | cd /Path/To/SDFGAN
38 | python stream_freqsplit.py ${PROJ_DIR}/data/synset_02958343_car ${PROJ_DIR}/data/synset_02958343_car_freqsplit 0 0
39 | ```
40 | The above command runs the frequency-split algorithm on multiple threads. The trailing `0 0` arguments default the code to using all available threads and converting all data within the dataset. A sketch of what the split produces for a single sample is shown below.
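For reference, here is a minimal sketch of what the split produces for a single sample, assuming the project's Python dependencies are installed and reusing the `freq_split` helper from `ops.py` (file names are hypothetical):
```python
import numpy as np
from ops import freq_split

# load one 64x64x64 signed distance field (hypothetical file name)
sdf = np.load('sample_sdf.npy')

# split into low- and high-frequency components with the default boxed mask
sdf_lf, sdf_hf = freq_split(sdf, r=8, mask_type='boxed')

# the two components sum back to the original field (up to FFT round-off)
assert np.allclose(sdf_lf + sdf_hf, sdf)

# stack as (input, target), the layout the pix2pix training code reads
np.save('sample_sdf_freqsplit.npy', np.stack([sdf_lf, sdf_hf]))
```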
41 |
42 | ### Running the sdfgan part of the code
43 | ```
44 | ./train_sdfgan.sh
45 | ```
46 | ### Running the pix2pix part of the code
47 | ```
48 | ./train_pix2pix.sh
49 | ```
50 | ### Generate test results
51 | ```
52 | ./test_all.sh
53 | ```
54 | ### Post processing for mesh
55 | Code for post-processing the results into meshes is not included in this repository. However, this can easily be achieved with a Marching Cubes algorithm, as sketched below. The mesh in the example above was further post-processed with three steps of Laplacian smoothing, quadratic mesh decimation by half, and hole-filling for the holes at the front and back that result from boundary effects of the bounding box.
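As a minimal sketch (not part of this repository), the zero level set can be extracted with the Marching Cubes implementation in scikit-image, assuming `scikit-image` is installed; the output file name is hypothetical:
```python
import numpy as np
from skimage import measure

# final samples are saved with shape
# (combine_freq/low_freq/hi_freq, sample_num, dim0, dim1, dim2)
samples = np.load('samples/full_final_samples.npy')
sdf = samples[0, 0]  # combined-frequency field of the first sample

# extract the zero level set of the signed distance function
# (older scikit-image versions call this measure.marching_cubes_lewiner)
verts, faces, _normals, _values = measure.marching_cubes(sdf, level=0.0)

# write a simple OBJ file; OBJ vertex indices are 1-based
with open('car_sample.obj', 'w') as f:
    for v in verts:
        f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
    for tri in faces:
        f.write('f {} {} {}\n'.format(tri[0] + 1, tri[1] + 1, tri[2] + 1))
```
Laplacian smoothing, decimation, and hole-filling can then be applied in Meshlab as described above.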
56 |
57 | ## Open Source Code Credits
58 | Borrows code from the following repositories:
59 | * [DCGAN-tensorflow](http://carpedm20.github.io/) repository by Taehoon Kim
60 | * [Pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow) repository by affinelayer
61 |
62 | Relies on the following open-source projects for pre- and post-processing of 3D meshes / signed distance functions:
63 | * [Libigl](https://github.com/libigl/libigl) interactive graphics library by Alec Jacobson et al.
64 | * [Meshlab](http://www.meshlab.net/) for mesh rendering and minor post-processing.
65 |
--------------------------------------------------------------------------------
/define_path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # variables to define by user
4 | PROJ_DIR=/home/maxjiang/Codes/project_sdfgan/sdfgan_data
5 | JOB_NAME=job6-plane
6 | DATASET=synset_02958343_car
7 |
8 | # export variables
9 | export PROJ_DIR
10 | export JOB_NAME
11 | export DATASET
--------------------------------------------------------------------------------
/images/car_15_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/car_15_small.gif
--------------------------------------------------------------------------------
/images/car_sample.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/car_sample.stl
--------------------------------------------------------------------------------
/images/furniture-render.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/furniture-render.gif
--------------------------------------------------------------------------------
/images/morph.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjiang93/SDFGAN/8ad62e13505a5d6652699f737ff045834e0246e3/images/morph.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from model_sdfgan import SDFGAN
2 | from model_pix2pix import Pix2Pix
3 | from utils import *
4 | from ops import batch_lowpass, batch_mirr
5 | import shutil
6 |
7 | import tensorflow as tf
8 |
9 | # common training flags
10 | flags = tf.app.flags
11 | flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
12 | flags.DEFINE_string("model", "sdfgan", "Model to train. Choice of 'sdfgan' or 'pix2pix' [sdfgan]")
13 | flags.DEFINE_boolean("is_new", False, "True for training from scratch, deleting existing checkpoints, logs and samples [False]")
14 | flags.DEFINE_integer("epoch", 100, "Epoch to train [100]")
15 | flags.DEFINE_float("d_learning_rate", 0.0002, "Learning rate of discrim for adam [0.0002]")
16 | flags.DEFINE_float("g_learning_rate", 0.0005, "Learning rate of gen for adam [0.0005]")
17 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
18 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
19 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
20 | flags.DEFINE_integer("image_depth", 64, "The size of sdf field to use. [64]")
21 | flags.DEFINE_integer("image_height", None, "The size of sdf to use. If None, same value as image_depth [None]")
22 | flags.DEFINE_integer("image_width", None, "The size of sdf to use. If None, same value as image_depth [None]")
23 | flags.DEFINE_integer("c_dim", 1, "Number of channels. [1]")
24 | flags.DEFINE_string("dataset", "shapenet", "The name of dataset [shapenet]")
25 | flags.DEFINE_string("input_fname_pattern", "*.npy", "Glob pattern of filename of input sdf [*.npy]")
26 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
27 | flags.DEFINE_string("dataset_dir", "data", "Directory name to read the input training data [data]")
28 | flags.DEFINE_string("log_dir", "logs", "Directory name to save the log files [logs]")
29 | flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
30 | flags.DEFINE_integer("num_gpus", 1, "Number of GPUs to use [1]")
31 | flags.DEFINE_integer("random_seed", 1, "Random seed [1]")
32 | flags.DEFINE_integer("save_interval", 200, "Interval in steps for saving model and samples [200]")
33 |
34 | # pix2pix specific training flags
35 | flags.DEFINE_integer("gan_weight", 1, "GAN weight in generator loss function. [1]")
36 | flags.DEFINE_integer("l1_weight", 100, "L1 weight in generator loss function. [100]")
37 |
38 | # testing flags
39 | flags.DEFINE_integer("sample_num", 16, "Number of samples. [16]")
40 | flags.DEFINE_string("test_from_input_path", None, "Path to test inputs. [None]")
41 |
42 | FLAGS = flags.FLAGS
43 |
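# Example invocations (for illustration only; the flags mirror the definitions above):
#   python main.py --model sdfgan --is_train True --dataset synset_02958343_car
#   python main.py --model pix2pix --is_train True --dataset synset_02958343_car_freqsplit
#   python main.py --is_train False  # test the full SDFGAN + Pix2Pix pipeline
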
44 | def main(_):
45 | pp.pprint(flags.FLAGS.__flags)
46 |
47 | if FLAGS.image_height is None:
48 | FLAGS.image_height = FLAGS.image_depth
49 | if FLAGS.image_width is None:
50 | FLAGS.image_width = FLAGS.image_depth
51 |
52 | if FLAGS.is_new:
53 | if os.path.exists(FLAGS.checkpoint_dir):
54 | shutil.rmtree(FLAGS.checkpoint_dir)
55 | if os.path.exists(FLAGS.sample_dir):
56 | shutil.rmtree(FLAGS.sample_dir)
57 | if os.path.exists(FLAGS.log_dir):
58 | shutil.rmtree(FLAGS.log_dir)
59 |
60 | if not os.path.exists(FLAGS.checkpoint_dir):
61 | os.makedirs(FLAGS.checkpoint_dir)
62 | if not os.path.exists(FLAGS.sample_dir):
63 | os.makedirs(FLAGS.sample_dir)
64 |
65 | # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
66 | run_config = tf.ConfigProto()
67 | run_config.gpu_options.allow_growth = True
68 |
69 | # set random seed
70 | tf.set_random_seed(FLAGS.random_seed)
71 | np.random.seed(FLAGS.random_seed)
72 |
73 | init_sdfgan = lambda sess: SDFGAN(
74 | sess,
75 | image_depth=FLAGS.image_depth,
76 | image_height=FLAGS.image_height,
77 | image_width=FLAGS.image_width,
78 | batch_size=FLAGS.batch_size,
79 | sample_num=FLAGS.batch_size,
80 | c_dim=FLAGS.c_dim,
81 | dataset_name=FLAGS.dataset,
82 | input_fname_pattern=FLAGS.input_fname_pattern,
83 | checkpoint_dir=FLAGS.checkpoint_dir,
84 | dataset_dir=FLAGS.dataset_dir,
85 | log_dir=FLAGS.log_dir,
86 | sample_dir=FLAGS.sample_dir,
87 | num_gpus=FLAGS.num_gpus,
88 | save_interval=FLAGS.save_interval)
89 |
90 | init_pix2pix = lambda sess: Pix2Pix(
91 | sess,
92 | image_depth=FLAGS.image_depth,
93 | image_height=FLAGS.image_height,
94 | image_width=FLAGS.image_width,
95 | batch_size=FLAGS.batch_size,
96 | sample_num=FLAGS.sample_num,
97 | gan_weight=FLAGS.gan_weight,
98 | l1_weight=FLAGS.l1_weight,
99 | c_dim=FLAGS.c_dim,
100 | dataset_name=FLAGS.dataset,
101 | input_fname_pattern=FLAGS.input_fname_pattern,
102 | checkpoint_dir=FLAGS.checkpoint_dir,
103 | dataset_dir=FLAGS.dataset_dir,
104 | log_dir=FLAGS.log_dir,
105 | sample_dir=FLAGS.sample_dir,
106 | num_gpus=FLAGS.num_gpus,
107 | save_interval=FLAGS.save_interval)
108 |
109 | # Train SDFGAN
110 | if FLAGS.model == 'sdfgan' and FLAGS.is_train:
111 | with tf.Session(config=run_config) as sess:
112 | sdfgan = init_sdfgan(sess)
113 | show_all_variables()
114 | if not os.path.exists(os.path.join(FLAGS.sample_dir, "sdfgan_sample")):
115 | os.makedirs(os.path.join(FLAGS.sample_dir, "sdfgan_sample"))
116 |             if not os.path.exists(os.path.join(FLAGS.log_dir, "sdfgan_log")):
117 |                 os.makedirs(os.path.join(FLAGS.log_dir, "sdfgan_log"))
118 | sdfgan.train(FLAGS)
119 |
120 | # Train Pix2Pix
121 | elif FLAGS.model == 'pix2pix' and FLAGS.is_train:
122 | with tf.Session(config=run_config) as sess:
123 | pix2pix = init_pix2pix(sess)
124 | show_all_variables()
125 | if not os.path.exists(os.path.join(FLAGS.sample_dir, "pix2pix_sample")):
126 | os.makedirs(os.path.join(FLAGS.sample_dir, "pix2pix_sample"))
127 | if not os.path.exists(os.path.join(FLAGS.log_dir, "pix2pix_log")):
128 | os.makedirs(os.path.join(FLAGS.log_dir, "pix2pix_log"))
129 | pix2pix.train(FLAGS)
130 |
131 | # Test entire network
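    # The test pipeline chains both networks: SDFGAN generates raw SDF samples,
    # which are low-pass filtered and mirrored; Pix2Pix then predicts the
    # high-frequency residual, and the final field is the sum of both components.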
132 | elif not FLAGS.is_train:
133 | # load sdfgan network to create new samples if necessary
134 | if FLAGS.test_from_input_path is None:
135 | with tf.Session(config=run_config) as sess_0:
136 | sdfgan = init_sdfgan(sess_0)
137 | show_all_variables()
138 | # generate a sample from sdfgan first
139 | if not sdfgan.load(FLAGS.checkpoint_dir):
140 | raise Exception("[!] Could not load SDFGAN model. Train first, then run test mode.")
141 | FLAGS.test_from_input_path = create_sdfgan_samples(sess_0, sdfgan, FLAGS)
142 | print("Saving intermediate unprocessed samples to {0}".format(FLAGS.test_from_input_path))
143 | sess_0.close()
144 |
145 | # post-process samples (low-pass filter & mirroring)
146 | sdf = np.squeeze(np.load(FLAGS.test_from_input_path), axis=-1)
147 | sdf_lf = batch_mirr(batch_lowpass(sdf))
148 |
149 | # create and save final samples
150 | tf.reset_default_graph()
151 | with tf.Session(config=run_config) as sess_1:
152 | pix2pix = init_pix2pix(sess_1)
153 | show_all_variables()
154 | if not pix2pix.load(FLAGS.checkpoint_dir):
155 | raise Exception("[!] Could not load Pix2Pix model. Train first, then run test mode.")
156 | sdf_hf = np.squeeze(create_pix2pix_samples(sess_1, pix2pix, sdf_lf), axis=-1)
157 | sdf_all = sdf_lf + sdf_hf
158 | sdf_save = np.concatenate((np.expand_dims(sdf_all, axis=0),
159 | np.expand_dims(sdf_lf, axis=0),
160 | np.expand_dims(sdf_hf, axis=0)), axis=0)
161 | fname = os.path.join(FLAGS.sample_dir, "full_final_samples.npy")
162 | np.save(fname, sdf_save)
163 | print("Saving final full samples to {0}, "
164 | "shape: {1} (combine_freq/low_freq/hi_freq, sample_num, dim0, dim1, dim2)"
165 | .format(fname, sdf_save.shape))
166 | sess_1.close()
167 |
168 | else:
169 | raise Exception("[!] Model must be 'sdfgan' or 'pix2pix'.")
170 |
171 |
172 | if __name__ == '__main__':
173 | tf.app.run()
174 |
--------------------------------------------------------------------------------
/model_pix2pix.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import time
3 | from glob import glob
4 | import math
5 |
6 | from ops import *
7 | from utils import *
8 |
9 | EPS = 1e-12
10 |
11 |
12 | class Pix2Pix(object):
13 | def __init__(self, sess, image_depth=64, image_height=64, image_width=64,
14 | batch_size=64, sample_num=64, gf_dim=64, df_dim=64, gan_weight=1, l1_weight=100,
15 | c_dim=1, dataset_name='shapenet_freqsplit', num_gpus=1, save_interval=200,
16 | input_fname_pattern='*.npy', checkpoint_dir=None, dataset_dir=None, log_dir=None, sample_dir=None):
17 | """
18 |
19 | Args:
20 | sess: TensorFlow session
21 | batch_size: The size of batch. Should be specified before training.
22 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
23 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
24 |             c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [1]
25 | """
26 | self.sess = sess
27 |
28 | self.batch_size = batch_size
29 | self.sample_num = sample_num
30 |
31 | self.image_depth = image_depth
32 | self.image_height = image_height
33 | self.image_width = image_width
34 |
35 | self.gf_dim = gf_dim
36 | self.df_dim = df_dim
37 |
38 | self.c_dim = c_dim
39 |         self.save_interval = save_interval
40 |
41 | self.gan_weight = gan_weight
42 | self.l1_weight = l1_weight
43 |
44 | self.num_gpus = num_gpus
45 | self.glob_batch_size = self.num_gpus * self.batch_size
46 |
47 | self.dataset_name = dataset_name
48 | self.input_fname_pattern = input_fname_pattern
49 | self.checkpoint_dir = checkpoint_dir
50 | self.dataset_dir = dataset_dir
51 | self.log_dir = os.path.join(log_dir, "pix2pix_log")
52 | self.sample_dir = os.path.join(sample_dir, "pix2pix_sample")
53 | self.build_model()
54 |
55 | def build_model(self):
56 |
57 | image_dims = [self.image_depth, self.image_height, self.image_width, self.c_dim]
58 |
59 | # input placeholders
60 | self.inputs = tf.placeholder(
61 | tf.float32, [self.glob_batch_size] + image_dims, name='input_images')
62 | self.targets = tf.placeholder(
63 | tf.float32, [self.glob_batch_size] + image_dims, name='target_images')
64 | self.sample_inputs = tf.placeholder(
65 | tf.float32, [self.sample_num] + image_dims, name='sample_inputs')
66 |
67 | self.n_eff = tf.placeholder(tf.int32, name='n_eff') # overall number of effective data points
68 |
69 | # initialize global lists
70 | self.G = [None] * self.num_gpus
71 | self.D = [None] * self.num_gpus
72 | self.D_ = [None] * self.num_gpus
73 | self.d_losses = [None] * self.num_gpus
74 | self.g_losses = [None] * self.num_gpus
75 | self.g_losses_gan = [None] * self.num_gpus
76 | self.g_losses_l1 = [None] * self.num_gpus
77 | self.n_effs = [None] * self.num_gpus
78 |
79 | # compute using multiple gpus
80 | with tf.variable_scope(tf.get_variable_scope()) as vscope:
81 | for gpuid in xrange(self.num_gpus):
82 | with tf.device('/gpu:%d' % gpuid):
83 |
84 | # range of data for this gpu
85 | gpu_start = gpuid * self.batch_size
86 | gpu_end = (gpuid + 1) * self.batch_size
87 |
88 | # number of effective data points
89 | gpu_n_eff = tf.reduce_min([tf.reduce_max([0, self.n_eff - gpu_start]), self.batch_size])
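                    # i.e. clamp(n_eff - gpu_start, 0, batch_size), used to exclude
                    # zero-padded entries from the GAN losses below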
90 |
91 | # create examples and pass through discriminator
92 | gpu_G = self.generator(self.inputs[gpu_start:gpu_end])
93 | gpu_D = self.discriminator(self.inputs[gpu_start:gpu_end], self.targets[gpu_start:gpu_end]) # real
94 | gpu_D_ = self.discriminator(self.inputs[gpu_start:gpu_end], gpu_G, reuse=True) # fake pairs
95 |
96 | # compute discriminator loss
97 | gpu_d_loss = tf.reduce_mean(-(tf.log(gpu_D[:gpu_n_eff] + EPS)
98 | + tf.log(1 - gpu_D_[:gpu_n_eff] + EPS)))
99 |
100 | # compute generator loss
101 | gpu_g_loss_gan = tf.reduce_mean(-tf.log(gpu_D[:gpu_n_eff] + EPS))
102 | gpu_g_loss_l1 = tf.reduce_mean(tf.abs(self.targets[gpu_start:gpu_end] - gpu_G))
103 | gpu_g_loss = gpu_g_loss_gan * self.gan_weight + gpu_g_loss_l1 * self.l1_weight
104 |
105 | # add gpu-wise data to global list
106 | self.G[gpuid] = gpu_G
107 | self.D[gpuid] = gpu_D
108 | self.D_[gpuid] = gpu_D_
109 | self.d_losses[gpuid] = gpu_d_loss
110 | self.g_losses[gpuid] = gpu_g_loss
111 | self.g_losses_gan[gpuid] = gpu_g_loss_gan
112 | self.g_losses_l1[gpuid] = gpu_g_loss_l1
113 | self.n_effs[gpuid] = gpu_n_eff
114 |
115 | # Reuse variables for the next gpu
116 | tf.get_variable_scope().reuse_variables()
117 |
118 | # concatenate across GPUs
119 | self.D = tf.concat(self.D, axis=0)
120 | self.D_ = tf.concat(self.D_, axis=0)
121 | self.G = tf.concat(self.G, axis=0)
122 | weighted_d_loss = [self.d_losses[j] * tf.cast(self.n_effs[j], tf.float32)
123 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
124 | weighted_g_loss = [self.g_losses[j] * tf.cast(self.n_effs[j], tf.float32)
125 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
126 | weighted_g_loss_gan = [tf.cast(self.g_losses_gan[j], tf.float32) * tf.cast(self.n_effs[j], tf.float32)
127 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
128 | weighted_g_loss_l1 = [tf.cast(self.g_losses_l1[j], tf.float32) * tf.cast(self.n_effs[j], tf.float32)
129 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
130 |
131 | self.d_loss = tf.reduce_sum(weighted_d_loss, axis=0)
132 | self.g_loss = tf.reduce_sum(weighted_g_loss, axis=0)
133 | self.g_loss_gan = tf.reduce_sum(weighted_g_loss_gan, axis=0)
134 | self.g_loss_l1 = tf.reduce_sum(weighted_g_loss_l1, axis=0)
135 |
136 | # summarize variables
137 | self.d_sum = histogram_summary("d", self.D)
138 | self.d__sum = histogram_summary("d_", self.D_)
139 | self.g_sum = image_summary("G", self.G[:, int(self.image_depth / 2), :, :])
140 |
141 | self.d_loss_sum = scalar_summary("d_loss", self.d_loss)
142 |
143 | self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
144 | self.g_loss_gan_sum = scalar_summary("g_loss_gan", self.g_loss_gan)
145 | self.g_loss_l1_sum = scalar_summary("g_loss_l1", self.g_loss_l1)
146 |
147 | # define trainable variables for generator and discriminator
148 | t_vars = tf.trainable_variables()
149 | self.d_vars = [var for var in t_vars if 'd_' in var.name]
150 | self.g_vars = [var for var in t_vars if 'g_' in var.name]
151 |
152 | self.sampler = self.sampler(self.sample_inputs)
153 | self.saver = tf.train.Saver()
154 |
155 | def train(self, config):
156 |         """Train Pix2Pix"""
157 | data = glob(os.path.join(self.dataset_dir, config.dataset, self.input_fname_pattern))
158 | np.random.shuffle(data)
159 |
160 | # define optimization operation
161 | d_opt = tf.train.AdamOptimizer(config.d_learning_rate, beta1=config.beta1)
162 | g_opt = tf.train.AdamOptimizer(config.g_learning_rate, beta1=config.beta1)
163 |
164 | # create list of grads from different gpus
165 | global_d_grads_vars = [None] * self.num_gpus
166 | global_g_grads_vars = [None] * self.num_gpus
167 |
168 | # compute d gradients
169 | with tf.variable_scope(tf.get_variable_scope()):
170 | for gpuid in xrange(self.num_gpus):
171 | with tf.device('/gpu:%d' % gpuid):
172 | gpu_d_grads_vars = d_opt.compute_gradients(loss=self.d_losses[gpuid], var_list=self.d_vars)
173 | global_d_grads_vars[gpuid] = gpu_d_grads_vars
174 |
175 | # compute g gradients
176 | with tf.variable_scope(tf.get_variable_scope()):
177 | for gpuid in xrange(self.num_gpus):
178 | with tf.device('/gpu:%d' % gpuid):
179 | gpu_g_grads_vars = g_opt.compute_gradients(loss=self.g_losses[gpuid], var_list=self.g_vars)
180 | global_g_grads_vars[gpuid] = gpu_g_grads_vars
181 |
182 | # average gradients across gpus and apply gradients
183 | d_grads_vars = average_gradients(global_d_grads_vars)
184 | g_grads_vars = average_gradients(global_g_grads_vars)
185 | d_optim = d_opt.apply_gradients(d_grads_vars)
186 | g_optim = g_opt.apply_gradients(g_grads_vars)
187 |
188 | # compatibility across tf versions
189 | try:
190 | tf.global_variables_initializer().run()
191 | except:
192 | tf.initialize_all_variables().run()
193 |
194 | self.g_sum = merge_summary([self.d__sum, self.g_loss_gan_sum, self.g_loss_l1_sum, self.g_loss_sum, self.g_sum])
195 | self.d_sum = merge_summary([self.d_sum, self.d_loss_sum])
196 |
197 | self.writer = SummaryWriter(self.log_dir, self.sess.graph)
198 | sample_files = data[0:self.sample_num]
199 |
200 | sample_inputs = [np.load(sample_file)[0, :, :, :] for sample_file in sample_files]
201 | sample_targets = [np.load(sample_file)[1, :, :, :] for sample_file in sample_files]
202 | sample_in = np.array(sample_inputs).astype(np.float32)[:, :, :, :, None]
203 | sample_tg = np.array(sample_targets).astype(np.float32)[:, :, :, :, None]
204 |
205 | counter = 0
206 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
207 | if could_load:
208 | counter = checkpoint_counter
209 | print(" [*] Load SUCCESS")
210 | else:
211 | print(" [!] Load failed...")
212 |
213 | batch_idxs = int(math.ceil(min(len(data), config.train_size) / self.glob_batch_size)) - 1
214 | total_steps = config.epoch * batch_idxs
215 | prev_time = -np.inf
216 |
217 | for epoch in xrange(config.epoch):
218 | # shuffle data before training in each epoch
219 | np.random.shuffle(data)
220 | for idx in xrange(0, batch_idxs):
221 | glob_batch_files = data[idx * self.glob_batch_size:(idx + 1) * self.glob_batch_size]
222 | glob_batch_inputs = [
223 | np.load(batch_file)[0, :, :, :] for batch_file in glob_batch_files]
224 | glob_batch_targets = [
225 | np.load(batch_file)[1, :, :, :] for batch_file in glob_batch_files]
226 | glob_batch_in = np.array(glob_batch_inputs).astype(np.float32)[:, :, :, :, None]
227 | glob_batch_tg = np.array(glob_batch_targets).astype(np.float32)[:, :, :, :, None]
228 |
229 | n_eff = len(glob_batch_files)
230 |
231 | # Pad zeros if effective batch size is smaller than global batch size
232 | if n_eff != self.glob_batch_size:
233 | glob_batch_in = pad_glob_batch(glob_batch_in, self.glob_batch_size)
234 | glob_batch_tg = pad_glob_batch(glob_batch_tg, self.glob_batch_size)
235 |
236 | # Update D network
237 | _, summary_str = self.sess.run([d_optim, self.d_sum],
238 | feed_dict={self.inputs: glob_batch_in,
239 | self.targets: glob_batch_tg,
240 | self.n_eff: n_eff})
241 | self.writer.add_summary(summary_str, counter)
242 |
243 | # Update G network
244 | _, summary_str = self.sess.run([g_optim, self.g_sum],
245 | feed_dict={self.inputs: glob_batch_in,
246 | self.targets: glob_batch_tg,
247 | self.n_eff: n_eff})
248 | self.writer.add_summary(summary_str, counter)
249 |
250 | # Compute last batch accuracy and losses
251 | lossD, lossG = self.sess.run([self.d_loss, self.g_loss],
252 | feed_dict={self.inputs: glob_batch_in,
253 | self.targets: glob_batch_tg,
254 | self.n_eff: n_eff})
255 |
256 | # get time
257 | now_time = time.time()
258 | time_per_iter = now_time - prev_time
259 | prev_time = now_time
260 | eta = (total_steps - counter + checkpoint_counter) * time_per_iter
261 | counter += 1
262 |
263 | try:
264 | timestr = time.strftime("%H:%M:%S", time.gmtime(eta))
265 | except:
266 | timestr = '?:?:?'
267 |
268 | print("Epoch:[%3d] [%3d/%3d] Iter:[%5d] eta(h:m:s): %s, d_loss: %.8f, g_loss: %.8f"
269 | % (epoch, idx, batch_idxs, counter, timestr, lossD, lossG))
270 |
271 | # save checkpoint and samples every save_interval steps
272 |                 if np.mod(counter, self.save_interval) == 1:
273 | sample_gen = self.sess.run(self.sampler, feed_dict={self.sample_inputs: sample_in})
274 | sample = np.concatenate((np.expand_dims(sample_in, axis=0), # sample_num x 64 x 64 x 64 x 1
275 | np.expand_dims(sample_tg, axis=0), # sample_num x 64 x 64 x 64 x 1
276 | np.expand_dims(sample_gen, axis=0)), axis=0) # sample_num x 64 x 64 x 64 x 1
277 | np.save(self.sample_dir+'/sample_{:05d}.npy'
278 | .format(counter), sample)
279 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint."
280 | .format(counter, self.sample_num))
281 | self.save(config.checkpoint_dir, counter)
282 |
283 | # save last checkpoint
284 | sample_gen = self.sess.run(self.sampler, feed_dict={self.sample_inputs: sample_in})
285 | sample = np.concatenate((np.expand_dims(sample_in[:, :, :, :, 0], axis=0), # sample_num x 64 x 64 x 64
286 | np.expand_dims(sample_tg[:, :, :, :, 0], axis=0), # sample_num x 64 x 64 x 64
287 | np.expand_dims(sample_gen[:, :, :, :, 0], axis=0)), axis=0) # sample_num x 64 x 64 x 64
288 | np.save(self.sample_dir+'/sample_{:05d}.npy'
289 | .format(counter), sample)
290 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint."
291 | .format(counter, self.sample_num))
292 | self.save(config.checkpoint_dir, counter)
293 |
294 | print("[!] Training of Pix2Pix Network Complete.")
295 |
296 | def discriminator(self, discrim_inputs, discrim_targets, reuse=False):
297 | with tf.variable_scope("pix2pix_discriminator") as scope:
298 | if reuse:
299 | scope.reuse_variables()
300 | n_layers = 3
301 | layers = []
302 |
303 | # 2x [batch, depth, height, width, in_channels] => [batch, depth, height, width, in_channels * 2]
304 | input = tf.concat([discrim_inputs, discrim_targets], axis=-1)
305 |
306 | # layer_1: [batch, 64, 64, 64, in_channels * 2] => [batch, 32, 32, 32, df_dim]
307 | convolved = conv3d(input, self.df_dim, name='d_h0_conv')
308 | rectified = lrelu(convolved, 0.2)
309 | layers.append(rectified)
310 |
311 | # layer_2: [batch, 32, 32, 32, df_dim] => [batch, 16, 16, 16, df_dim * 2]
312 | # layer_3: [batch, 16, 16, 16, df_dim * 2] => [batch, 8, 8, 8, df_dim * 4]
313 |             # layer_4: [batch, 8, 8, 8, df_dim * 4] => [batch, 8, 8, 8, df_dim * 8] (stride-1 conv, SAME padding)
314 | for i in range(n_layers):
315 | # with tf.variable_scope("layer_%d" % (len(layers) + 1)):
316 | out_channels = self.df_dim * min(2 ** (i + 1), 8)
317 | stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
318 | convolved = conv3d(layers[-1], out_channels,
319 | d_d=stride, d_h=stride, d_w=stride, name='d_h%d_conv' % (len(layers)))
320 | normalized = batchnorm(convolved, name='d_h%d_bn' % (len(layers)))
321 | rectified = lrelu(normalized, 0.2)
322 | layers.append(rectified)
323 |
324 |             # layer_5: [batch, 8, 8, 8, df_dim * 8] => [batch, 8, 8, 8, 1]
325 | # with tf.variable_scope("layer_%d" % (len(layers) + 1)):
326 | convolved = conv3d(rectified, output_dim=1, d_d=1, d_h=1, d_w=1, name='d_h%d_conv' % (len(layers)))
327 | output = tf.sigmoid(convolved)
328 | layers.append(output)
329 |
330 | return layers[-1]
331 |
332 | def generator(self, generator_inputs):
333 | with tf.variable_scope("pix2pix_generator") as scope:
334 | layers = []
335 |
336 |             # encoder_1: [batch, 64, 64, 64, in_channels] => [batch, 32, 32, 32, gf_dim]
337 | output = conv3d(generator_inputs, self.gf_dim, name="g_enc_conv_0")
338 | layers.append(output)
339 |
340 | layer_specs = [
341 | self.gf_dim * 2, # encoder_2: [batch, 32, 32, 32, gf_dim] => [batch, 16, 16, 16, gf_dim * 2]
342 | self.gf_dim * 4, # encoder_3: [batch, 16, 16, 16, gf_dim * 2] => [batch, 8, 8, 8, gf_dim * 4]
343 |                 self.gf_dim * 8,  # encoder_4: [batch, 8, 8, 8, gf_dim * 4] => [batch, 4, 4, 4, gf_dim * 8]
344 | self.gf_dim * 8, # encoder_5: [batch, 4, 4, 4, gf_dim * 8] => [batch, 2, 2, 2, gf_dim * 8]
345 | self.gf_dim * 8, # encoder_6: [batch, 2, 2, 2, gf_dim * 8] => [batch, 1, 1, 1, gf_dim * 8]
346 | ]
347 |
348 | for out_channels in layer_specs:
349 | rectified = lrelu(layers[-1], 0.2)
350 |                 # [batch, in_depth, in_height, in_width, in_channels] => [batch, in_depth/2, in_height/2, in_width/2, out_channels]
351 | convolved = conv3d(rectified, out_channels, name="g_enc_cov_%d" % (len(layers)))
352 | output = batchnorm(convolved, name="g_enc_bn_%d" % (len(layers)))
353 | layers.append(output)
354 |
355 | layer_specs = [
356 | (self.gf_dim * 8, 0.5), # decoder_6: [batch,1,1,1,gf_dim * 8] => [batch,2,2,2,gf_dim * 8 * 2]
357 |                 (self.gf_dim * 8, 0.5),  # decoder_5: [batch,2,2,2,gf_dim * 8 * 2] => [batch,4,4,4,gf_dim * 8 * 2]
358 | (self.gf_dim * 4, 0.5), # decoder_4: [batch,4,4,4,gf_dim * 8 * 2] => [batch,8,8,8,gf_dim * 4 * 2]
359 | (self.gf_dim * 2, 0.0), # decoder_3: [batch,8,8,8,gf_dim * 4 * 2] => [batch,16,16,16,gf_dim * 2 * 2]
360 | (self.gf_dim * 1, 0.0), # decoder_2: [batch,16,16,16,gf_dim * 2 * 2] => [batch,32,32,32,gf_dim * 2]
361 | ]
362 |
363 | num_encoder_layers = len(layers)
364 | for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
365 | skip_layer = num_encoder_layers - decoder_layer - 1
366 | if decoder_layer == 0:
367 | # first decoder layer doesn't have skip connections
368 | # since it is directly connected to the skip_layer
369 | input = layers[-1]
370 | else:
371 | input = tf.concat([layers[-1], layers[skip_layer]], axis=-1)
372 |
373 | rectified = tf.nn.relu(input)
374 | # [batch, in_depth, in_height, in_width, in_channels]
375 |                 # => [batch, in_depth*2, in_height*2, in_width*2, out_channels]
376 | in_dims = rectified.get_shape().as_list()
377 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, out_channels]
378 | output = deconv3d(rectified, out_dims, name="g_dec_conv_%d" % skip_layer)
379 | output = batchnorm(output, name="g_dec_bn_%d" % skip_layer)
380 |
381 | if dropout > 0.0:
382 | output = tf.nn.dropout(output, keep_prob=1 - dropout)
383 |
384 | layers.append(output)
385 |
386 | # decoder_1: [batch, 32, 32, 32, gf_dim * 2] => [batch, 64, 64, 64, 1]
387 | input = tf.concat([layers[-1], layers[0]], axis=-1)
388 | rectified = tf.nn.relu(input)
389 | in_dims = rectified.get_shape().as_list()
390 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, 1]
391 | output = deconv3d(rectified, out_dims, name="g_dec_conv_0")
392 | output = tf.tanh(output)
393 | layers.append(output)
394 |
395 | return layers[-1]
396 |
397 | def sampler(self, generator_inputs):
398 | with tf.variable_scope("pix2pix_generator") as scope:
399 | scope.reuse_variables()
400 |
401 | layers = []
402 |
403 |             # encoder_1: [batch, 64, 64, 64, in_channels] => [batch, 32, 32, 32, gf_dim]
404 | output = conv3d(generator_inputs, self.gf_dim, name="g_enc_conv_0")
405 | layers.append(output)
406 |
407 | layer_specs = [
408 | self.gf_dim * 2, # encoder_2: [batch, 32, 32, 32, gf_dim] => [batch, 16, 16, 16, gf_dim * 2]
409 | self.gf_dim * 4, # encoder_3: [batch, 16, 16, 16, gf_dim * 2] => [batch, 8, 8, 8, gf_dim * 4]
410 |                 self.gf_dim * 8,  # encoder_4: [batch, 8, 8, 8, gf_dim * 4] => [batch, 4, 4, 4, gf_dim * 8]
411 | self.gf_dim * 8, # encoder_5: [batch, 4, 4, 4, gf_dim * 8] => [batch, 2, 2, 2, gf_dim * 8]
412 | self.gf_dim * 8, # encoder_6: [batch, 2, 2, 2, gf_dim * 8] => [batch, 1, 1, 1, gf_dim * 8]
413 | ]
414 |
415 | for out_channels in layer_specs:
416 | rectified = lrelu(layers[-1], 0.2)
417 |                 # [batch, in_depth, in_height, in_width, in_channels] => [batch, in_depth/2, in_height/2, in_width/2, out_channels]
418 | convolved = conv3d(rectified, out_channels, name="g_enc_cov_%d" % (len(layers)))
419 | output = batchnorm(convolved, name="g_enc_bn_%d" % (len(layers)))
420 | layers.append(output)
421 |
422 | layer_specs = [
423 | (self.gf_dim * 8, 0.5), # decoder_6: [batch,1,1,1,gf_dim * 8] => [batch,2,2,2,gf_dim * 8 * 2]
424 |                 (self.gf_dim * 8, 0.5),  # decoder_5: [batch,2,2,2,gf_dim * 8 * 2] => [batch,4,4,4,gf_dim * 8 * 2]
425 | (self.gf_dim * 4, 0.5), # decoder_4: [batch,4,4,4,gf_dim * 8 * 2] => [batch,8,8,8,gf_dim * 4 * 2]
426 | (self.gf_dim * 2, 0.0), # decoder_3: [batch,8,8,8,gf_dim * 4 * 2] => [batch,16,16,16,gf_dim * 2 * 2]
427 | (self.gf_dim * 1, 0.0), # decoder_2: [batch,16,16,16,gf_dim * 2 * 2] => [batch,32,32,32,gf_dim * 2]
428 | ]
429 |
430 | num_encoder_layers = len(layers)
431 | for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
432 | skip_layer = num_encoder_layers - decoder_layer - 1
433 | if decoder_layer == 0:
434 | # first decoder layer doesn't have skip connections
435 | # since it is directly connected to the skip_layer
436 | input = layers[-1]
437 | else:
438 | input = tf.concat([layers[-1], layers[skip_layer]], axis=-1)
439 |
440 | rectified = tf.nn.relu(input)
441 | # [batch, in_depth, in_height, in_width, in_channels]
442 |                 # => [batch, in_depth*2, in_height*2, in_width*2, out_channels]
443 | in_dims = rectified.get_shape().as_list()
444 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, out_channels]
445 | output = deconv3d(rectified, out_dims, name="g_dec_conv_%d" % skip_layer)
446 | output = batchnorm(output, name="g_dec_bn_%d" % skip_layer)
447 |
448 | if dropout > 0.0:
449 | output = tf.nn.dropout(output, keep_prob=1 - dropout)
450 |
451 | layers.append(output)
452 |
453 | # decoder_1: [batch, 32, 32, 32, gf_dim * 2] => [batch, 64, 64, 64, 1]
454 | input = tf.concat([layers[-1], layers[0]], axis=-1)
455 | rectified = tf.nn.relu(input)
456 | in_dims = rectified.get_shape().as_list()
457 | out_dims = [in_dims[0], in_dims[1] * 2, in_dims[2] * 2, in_dims[3] * 2, 1]
458 | output = deconv3d(rectified, out_dims, name="g_dec_conv_0")
459 | output = tf.tanh(output)
460 | layers.append(output)
461 |
462 | return layers[-1]
463 |
464 | @property
465 | def model_dir(self):
466 | return "Pix2Pix_" + "{}_{}_{}_{}".format(
467 | self.dataset_name.replace("_freqsplit", ""),
468 | self.image_depth, self.image_height, self.image_width)
469 |
470 | def save(self, checkpoint_dir, step):
471 | model_name = "Pix2Pix.model"
472 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
473 |
474 | if not os.path.exists(checkpoint_dir):
475 | os.makedirs(checkpoint_dir)
476 |
477 | self.saver.save(self.sess,
478 | os.path.join(checkpoint_dir, model_name),
479 | global_step=step)
480 |
481 | def load(self, checkpoint_dir):
482 | import re
483 | print(" [*] Reading checkpoints...")
484 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
485 |
486 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
487 | if ckpt and ckpt.model_checkpoint_path:
488 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
489 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
490 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
491 | print(" [*] Success to read {}".format(ckpt_name))
492 | return True, counter
493 | else:
494 | print(" [*] Failed to find a checkpoint")
495 | return False, 0
496 |
--------------------------------------------------------------------------------
/model_sdfgan.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import time
3 | from glob import glob
4 |
5 | from ops import *
6 | from utils import *
7 |
8 |
9 | def conv_out_size_same(size, stride):
10 | return int(math.ceil(float(size) / float(stride)))
11 |
12 |
13 | class SDFGAN(object):
14 | def __init__(self, sess, num_gpus=1, image_depth=64, image_height=64, image_width=64, save_interval=200,
15 |                  batch_size=64, sample_num=64, z_dim=200, gf_dim=64, df_dim=64, c_dim=1, dataset_name='shapenet',
16 | input_fname_pattern='*.npy', checkpoint_dir=None, dataset_dir=None, log_dir=None, sample_dir=None):
17 | """
18 |
19 | Args:
20 | sess: TensorFlow session
21 | batch_size: The size of batch. Should be specified before training.
22 | z_dim: (optional) Dimension of dim for Z. [200]
23 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
24 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
25 |             c_dim: (optional) Dimension of channels. For sdf input, set to 1. [1]
26 | """
27 | self.sess = sess
28 |
29 | self.batch_size = batch_size
30 | self.sample_num = sample_num
31 |
32 | self.image_depth = image_depth
33 | self.image_height = image_height
34 | self.image_width = image_width
35 |
36 | self.z_dim = z_dim
37 |
38 | self.gf_dim = gf_dim
39 | self.df_dim = df_dim
40 |
41 | self.c_dim = c_dim
42 | self.save_interval = save_interval
43 | self.num_gpus = num_gpus
44 | self.glob_batch_size = self.num_gpus * self.batch_size
45 |
46 | # batch normalization : deals with poor initialization helps gradient flow
47 | self.d_bn1 = batch_norm(name='d_bn1')
48 | self.d_bn2 = batch_norm(name='d_bn2')
49 | self.d_bn3 = batch_norm(name='d_bn3')
50 |
51 | self.g_bn0 = batch_norm(name='g_bn0')
52 | self.g_bn1 = batch_norm(name='g_bn1')
53 | self.g_bn2 = batch_norm(name='g_bn2')
54 | self.g_bn3 = batch_norm(name='g_bn3')
55 |
56 | self.dataset_name = dataset_name
57 | self.input_fname_pattern = input_fname_pattern
58 | self.checkpoint_dir = checkpoint_dir
59 | self.dataset_dir = dataset_dir
60 | self.log_dir = os.path.join(log_dir, "sdfgan_log")
61 | self.sample_dir = os.path.join(sample_dir, "sdfgan_sample")
62 | self.build_model()
63 |
64 | def build_model(self):
65 |
66 | image_dims = [self.image_depth, self.image_height, self.image_width, self.c_dim]
67 |
68 | # input placeholders
69 | self.inputs = tf.placeholder(
70 | tf.float32, [self.glob_batch_size] + image_dims, name='real_images')
71 | self.z = tf.placeholder(
72 | tf.float32, [None, self.z_dim], name='z') # latent vector
73 | self.z_sum = histogram_summary("z", self.z)
74 |
75 | self.n_eff = tf.placeholder(tf.int32, name='n_eff') # overall number of effective data points
76 |
77 | # initialize global lists
78 | self.G = [None] * self.num_gpus # generator results
79 | self.D = [None] * self.num_gpus # discriminator results for real images
80 | self.D_logits = [None] * self.num_gpus
81 | self.D_ = [None] * self.num_gpus # discriminator results for fake images
82 | self.D_logits_ = [None] * self.num_gpus
83 | self.d_loss_real = [None] * self.num_gpus
84 | self.d_loss_fake = [None] * self.num_gpus
85 | self.d_losses = [None] * self.num_gpus
86 | self.g_losses = [None] * self.num_gpus
87 | self.d_accus = [None] * self.num_gpus # discriminator accuracy
88 | self.n_effs = [None] * self.num_gpus
89 |
90 | # compute using multiple gpus
91 | with tf.variable_scope(tf.get_variable_scope()) as vscope:
92 | for gpuid in xrange(self.num_gpus):
93 | with tf.device('/gpu:%d' % gpuid):
94 |
95 | # range of data for this gpu
96 | gpu_start = gpuid * self.batch_size
97 | gpu_end = (gpuid + 1) * self.batch_size
98 |
99 | # number of effective data points
100 | gpu_n_eff = tf.reduce_min([tf.reduce_max([0, self.n_eff - gpu_start]), self.batch_size])
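                    # i.e. clamp(n_eff - gpu_start, 0, batch_size), used to exclude
                    # zero-padded entries from the loss and accuracy computations below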
101 |
102 | # create examples and pass through discriminator
103 | gpu_G = self.generator(self.z[gpu_start:gpu_end])
104 | gpu_D, gpu_D_logits = self.discriminator(self.inputs[gpu_start:gpu_end])
105 | gpu_D_, gpu_D_logits_ = self.discriminator(gpu_G, reuse=True)
106 |
107 | # compatibility across different tf versions
108 | def sigmoid_cross_entropy_with_logits(x, y):
109 | try:
110 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
111 | except:
112 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)
113 |
114 | # compute loss and accuracy
115 | gpu_d_loss_real = tf.reduce_mean(
116 | sigmoid_cross_entropy_with_logits(gpu_D_logits[:gpu_n_eff], tf.ones_like(gpu_D[:gpu_n_eff])))
117 | gpu_d_loss_fake = tf.reduce_mean(
118 | sigmoid_cross_entropy_with_logits(gpu_D_logits_[:gpu_n_eff], tf.zeros_like(gpu_D_[:gpu_n_eff])))
119 | gpu_g_loss_gen = tf.reduce_mean(
120 | sigmoid_cross_entropy_with_logits(gpu_D_logits_[:gpu_n_eff], tf.ones_like(gpu_D_[:gpu_n_eff])))
121 | gpu_d_loss = gpu_d_loss_real + gpu_d_loss_fake
122 | gpu_d_accu_real = tf.reduce_sum(tf.cast(gpu_D[:gpu_n_eff] > .5, tf.int32)) / gpu_D.get_shape()[0]
123 | gpu_d_accu_fake = tf.reduce_sum(tf.cast(gpu_D_[:gpu_n_eff] < .5, tf.int32)) / gpu_D_.get_shape()[0]
124 | gpu_d_accu = (gpu_d_accu_real + gpu_d_accu_fake) / 2
125 |
126 | # combined generator loss
127 | gpu_g_loss = gpu_g_loss_gen
128 |
129 | # add gpu-wise data to global list
130 | self.G[gpuid] = gpu_G
131 | self.D[gpuid] = gpu_D
132 | self.D_[gpuid] = gpu_D_
133 | self.D_logits[gpuid] = gpu_D_logits
134 | self.D_logits_[gpuid] = gpu_D_logits_
135 | self.d_loss_real[gpuid] = gpu_d_loss_real
136 | self.d_loss_fake[gpuid] = gpu_d_loss_fake
137 | self.d_losses[gpuid] = gpu_d_loss
138 | self.g_losses[gpuid] = gpu_g_loss
139 | self.d_accus[gpuid] = gpu_d_accu
140 | self.n_effs[gpuid] = gpu_n_eff
141 |
142 | # Reuse variables for the next gpu
143 | tf.get_variable_scope().reuse_variables()
144 |
145 | # concatenate across GPUs
146 | self.D = tf.concat(self.D, axis=0)
147 | self.D_ = tf.concat(self.D_, axis=0)
148 | self.G = tf.concat(self.G, axis=0)
149 | weighted_d_loss_real = [self.d_loss_real[j] * tf.cast(self.n_effs[j], tf.float32)
150 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
151 | weighted_d_loss_fake = [self.d_loss_fake[j] * tf.cast(self.n_effs[j], tf.float32)
152 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
153 | weighted_d_loss = [self.d_losses[j] * tf.cast(self.n_effs[j], tf.float32)
154 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
155 | weighted_g_loss = [self.g_losses[j] * tf.cast(self.n_effs[j], tf.float32)
156 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
157 | weighted_d_accu = [tf.cast(self.d_accus[j], tf.float32) * tf.cast(self.n_effs[j], tf.float32)
158 | / tf.cast(self.n_eff, tf.float32) for j in range(self.num_gpus)]
159 |
160 | self.d_loss_real = tf.reduce_sum(weighted_d_loss_real, axis=0)
161 | self.d_loss_fake = tf.reduce_sum(weighted_d_loss_fake, axis=0)
162 | self.d_loss = tf.reduce_sum(weighted_d_loss, axis=0)
163 | self.g_loss = tf.reduce_sum(weighted_g_loss, axis=0)
164 | self.d_accu = tf.reduce_sum(weighted_d_accu, axis=0)
165 |
166 | # summarize variables
167 | self.d_sum = histogram_summary("d", self.D)
168 | self.d__sum = histogram_summary("d_", self.D_)
169 | self.g_sum = image_summary("G", self.G[:, int(self.image_depth / 2), :, :])
170 |
171 | self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
172 | self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)
173 | self.d_loss_sum = scalar_summary("d_loss", self.d_loss)
174 |
175 | self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
176 | self.d_accu_sum = scalar_summary("d_accu", self.d_accu)
177 |
178 | # define trainable variables for generator and discriminator
179 | t_vars = tf.trainable_variables()
180 | self.d_vars = [var for var in t_vars if 'd_' in var.name]
181 | self.g_vars = [var for var in t_vars if 'g_' in var.name]
182 |
183 | self.sampler = self.generator(self.z, reuse=True, is_train=False)
184 | self.saver = tf.train.Saver()
185 |
186 | def train(self, config):
187 |         """Train SDFGAN"""
188 | data = glob(os.path.join(self.dataset_dir, config.dataset, self.input_fname_pattern))
189 | np.random.shuffle(data)
190 |
191 | # define optimization operation
192 | d_opt = tf.train.AdamOptimizer(config.d_learning_rate, beta1=config.beta1)
193 | g_opt = tf.train.AdamOptimizer(config.g_learning_rate, beta1=config.beta1)
194 |
195 | # create list of grads from different gpus
196 | global_d_grads_vars = [None] * self.num_gpus
197 | global_g_grads_vars = [None] * self.num_gpus
198 |
199 | # compute d gradients
200 | with tf.variable_scope(tf.get_variable_scope()):
201 | for gpuid in xrange(self.num_gpus):
202 | with tf.device('/gpu:%d' % gpuid):
203 | gpu_d_grads_vars = d_opt.compute_gradients(loss=self.d_losses[gpuid], var_list=self.d_vars)
204 | global_d_grads_vars[gpuid] = gpu_d_grads_vars
205 |
206 | # compute g gradients
207 | with tf.variable_scope(tf.get_variable_scope()):
208 | for gpuid in xrange(self.num_gpus):
209 | with tf.device('/gpu:%d' % gpuid):
210 | gpu_g_grads_vars = g_opt.compute_gradients(loss=self.g_losses[gpuid], var_list=self.g_vars)
211 | global_g_grads_vars[gpuid] = gpu_g_grads_vars
212 |
213 | # average gradients across gpus and apply gradients
214 | d_grads_vars = average_gradients(global_d_grads_vars)
215 | g_grads_vars = average_gradients(global_g_grads_vars)
216 | d_optim = d_opt.apply_gradients(d_grads_vars)
217 | g_optim = g_opt.apply_gradients(g_grads_vars)
218 |
219 | # compatibility across tf versions
220 | try:
221 | tf.global_variables_initializer().run()
222 | except:
223 | tf.initialize_all_variables().run()
224 |
225 | self.g_sum = merge_summary([self.z_sum, self.d__sum, self.g_sum, self.d_loss_fake_sum, self.g_loss_sum])
226 | self.d_sum = merge_summary([self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
227 |
228 | self.writer = SummaryWriter(self.log_dir, self.sess.graph)
229 | sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))
230 |
231 | counter = 0
232 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
233 | if could_load:
234 | counter = checkpoint_counter
235 | print(" [*] Load SUCCESS")
236 | else:
237 | print(" [!] Load failed...")
238 |
239 | d_accu_last_batch = .5
240 | batch_idxs = int(math.ceil(min(len(data), config.train_size) / self.glob_batch_size)) - 1
241 | total_steps = config.epoch * batch_idxs
242 | prev_time = -np.inf
243 |
244 | for epoch in xrange(config.epoch):
245 | # shuffle data before training in each epoch
246 | np.random.shuffle(data)
247 | for idx in xrange(0, batch_idxs - 1):
248 | glob_batch_files = data[idx * self.glob_batch_size:(idx + 1) * self.glob_batch_size]
249 | glob_batch = [
250 | np.load(batch_file)[0, :, :, :] for batch_file in glob_batch_files]
251 | glob_batch_images = np.array(glob_batch).astype(np.float32)[:, :, :, :, None]
252 |
253 | glob_batch_z = np.random.uniform(-1, 1, [self.glob_batch_size, self.z_dim]) \
254 | .astype(np.float32)
255 | n_eff = len(glob_batch_files)
256 |
257 | # Pad zeros if effective batch size is smaller than global batch size
258 | if n_eff != self.glob_batch_size:
259 | glob_batch_images = pad_glob_batch(glob_batch_images, self.glob_batch_size)
260 |
261 |                 # Update D network only if its accuracy on the last batch was below 80% (keeps D from overpowering G)
262 | if d_accu_last_batch < .8:
263 | # Update D network
264 | _, summary_str = self.sess.run([d_optim, self.d_sum],
265 | feed_dict={self.inputs: glob_batch_images,
266 | self.z: glob_batch_z,
267 | self.n_eff: n_eff})
268 | self.writer.add_summary(summary_str, counter)
269 |
270 | # Update G network
271 | _, summary_str = self.sess.run([g_optim, self.g_sum],
272 | feed_dict={self.z: glob_batch_z,
273 | self.n_eff: n_eff})
274 | self.writer.add_summary(summary_str, counter)
275 |
276 | # Compute last batch accuracy and losses
277 | d_accu_last_batch, errD_fake, errD_real, errG \
278 | = self.sess.run([self.d_accu, self.d_loss_fake, self.d_loss_real, self.g_loss],
279 | feed_dict={self.inputs: glob_batch_images,
280 | self.z: glob_batch_z,
281 | self.n_eff: n_eff})
282 | self.writer.add_summary(summary_str, counter)
283 |
284 | # get time
285 | now_time = time.time()
286 | time_per_iter = now_time - prev_time
287 | prev_time = now_time
288 | eta = (total_steps - counter + checkpoint_counter) * time_per_iter
289 | counter += 1
290 |
291 | try:
292 | timestr = time.strftime("%H:%M:%S", time.gmtime(eta))
293 | except:
294 | timestr = '?:?:?'
295 |
296 | print("Epoch:[%3d] [%3d/%3d] Iter:[%5d] eta(h:m:s): %s, d_loss: %.8f, g_loss: %.8f, d_accu: %.4f"
297 | % (epoch, idx, batch_idxs,
298 | counter, timestr, errD_fake + errD_real, errG, d_accu_last_batch))
299 |
300 | # save model and samples every save_interval iterations
301 | if np.mod(counter, self.save_interval) == 1:
302 |                     samples = self.sess.run(self.sampler, feed_dict={self.z: sample_z})
303 | np.save(self.sample_dir+'/sample_{:05d}.npy'.format(counter), samples)
304 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint."
305 | .format(counter, self.sample_num))
306 | self.save(config.checkpoint_dir, counter)
307 |
308 | # save last checkpoint
309 |         samples = self.sess.run(self.sampler, feed_dict={self.z: sample_z})
310 | np.save(self.sample_dir+'/sample_{:05d}.npy'.format(counter), samples)
311 | print("[Sample] Iter {0}, saving sample size of {1}, saving checkpoint.".format(counter, self.sample_num))
312 | self.save(config.checkpoint_dir, counter)
313 |
314 |         print("[!] Training of SDFGAN Complete.")
315 |
316 | def discriminator(self, image, reuse=False):
317 | with tf.variable_scope("sdfgan_discriminator") as scope:
318 | if reuse:
319 | scope.reuse_variables()
320 |
321 | h0 = lrelu(conv3d(image, self.df_dim, name='d_h0_conv'))
322 | h1 = lrelu(self.d_bn1(conv3d(h0, self.df_dim * 2, name='d_h1_conv')))
323 | h2 = lrelu(self.d_bn2(conv3d(h1, self.df_dim * 4, name='d_h2_conv')))
324 | h3 = lrelu(self.d_bn3(conv3d(h2, self.df_dim * 8, name='d_h3_conv')))
325 | h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, name='d_h3_lin')
326 |
327 | return tf.nn.sigmoid(h4), h4
328 |
329 | def generator(self, z, reuse=False, is_train=True):
330 | with tf.variable_scope("sdfgan_generator") as scope:
331 | if reuse:
332 | scope.reuse_variables()
333 |
334 | s_d, s_h, s_w = self.image_depth, self.image_height, self.image_width
335 | s_d2, s_h2, s_w2 = conv_out_size_same(s_d, 2), conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
336 | s_d4, s_h4, s_w4 = conv_out_size_same(s_d2, 2), conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
337 | s_d8, s_h8, s_w8 = conv_out_size_same(s_d4, 2), conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
338 | s_d16, s_h16, s_w16 = conv_out_size_same(s_d8, 2), conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)
339 |
340 | # project `z` and reshape
341 | self.z_, self.h0_w, self.h0_b = linear(
342 | z, self.gf_dim * 8 * s_d16 * s_h16 * s_w16, 'g_h0_lin', with_w=True)
343 |
344 | self.h0 = tf.reshape(
345 | self.z_, [-1, s_d16, s_h16, s_w16, self.gf_dim * 8])
346 | h0 = tf.nn.relu(self.g_bn0(self.h0, train=is_train))
347 |
348 | self.h1, self.h1_w, self.h1_b = deconv3d(
349 | h0, [self.batch_size, s_d8, s_h8, s_w8, self.gf_dim * 4], name='g_h1', with_w=True)
350 | h1 = tf.nn.relu(self.g_bn1(self.h1, train=is_train))
351 |
352 | h2, self.h2_w, self.h2_b = deconv3d(
353 | h1, [self.batch_size, s_d4, s_h4, s_w4, self.gf_dim * 2], name='g_h2', with_w=True)
354 | h2 = tf.nn.relu(self.g_bn2(h2, train=is_train))
355 |
356 | h3, self.h3_w, self.h3_b = deconv3d(
357 | h2, [self.batch_size, s_d2, s_h2, s_w2, self.gf_dim * 1], name='g_h3', with_w=True)
358 | h3 = tf.nn.relu(self.g_bn3(h3, train=is_train))
359 |
360 | h4, self.h4_w, self.h4_b = deconv3d(
361 | h3, [self.batch_size, s_d, s_h, s_w, self.c_dim], name='g_h4', with_w=True)
362 |
363 | return tf.nn.tanh(h4)
364 |
365 | @property
366 | def model_dir(self):
367 | return "SDFGAN_" + "{}_{}_{}_{}".format(
368 | self.dataset_name,
369 | self.image_depth, self.image_height, self.image_width)
370 |
371 | def save(self, checkpoint_dir, step):
372 | model_name = "SDFGAN.model"
373 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
374 |
375 | if not os.path.exists(checkpoint_dir):
376 | os.makedirs(checkpoint_dir)
377 |
378 | self.saver.save(self.sess,
379 | os.path.join(checkpoint_dir, model_name),
380 | global_step=step)
381 |
382 | def load(self, checkpoint_dir):
383 | import re
384 | print(" [*] Reading checkpoints...")
385 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
386 |
387 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
388 | if ckpt and ckpt.model_checkpoint_path:
389 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
390 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
391 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
392 | print(" [*] Success to read {}".format(ckpt_name))
393 | return True, counter
394 | else:
395 | print(" [*] Failed to find a checkpoint")
396 | return False, 0
397 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 |
3 | try:
4 | image_summary = tf.summary.image
5 | scalar_summary = tf.summary.scalar
6 | histogram_summary = tf.summary.histogram
7 | merge_summary = tf.summary.merge
8 | SummaryWriter = tf.summary.FileWriter
9 | except:
10 | image_summary = tf.image_summary
11 | scalar_summary = tf.scalar_summary
12 | histogram_summary = tf.histogram_summary
13 | merge_summary = tf.merge_summary
14 | SummaryWriter = tf.train.SummaryWriter
15 |
16 | if "concat_v2" in dir(tf):
17 | def concat(tensors, axis, *args, **kwargs):
18 | return tf.concat_v2(tensors, axis, *args, **kwargs)
19 | else:
20 | def concat(tensors, axis, *args, **kwargs):
21 | return tf.concat(tensors, axis, *args, **kwargs)
22 |
23 |
24 | class batch_norm(object):
25 | def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
26 | with tf.variable_scope(name):
27 | self.epsilon = epsilon
28 | self.momentum = momentum
29 | self.name = name
30 |
31 | def __call__(self, x, train=True):
32 | return tf.contrib.layers.batch_norm(x,
33 | decay=self.momentum,
34 | updates_collections=None,
35 | epsilon=self.epsilon,
36 | scale=True,
37 | is_training=train,
38 | scope=self.name)
39 |
40 |
41 | def batchnorm(input, name="batchnorm"):
42 | with tf.variable_scope(name):
43 | # this block looks like it has 3 inputs on the graph unless we do this
44 | input = tf.identity(input)
45 |
46 | channels = input.get_shape()[-1]
47 | offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
48 | scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02))
49 | mean, variance = tf.nn.moments(input, axes=[0, 1, 2, 3], keep_dims=False)
50 | variance_epsilon = 1e-5
51 | normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
52 | return normalized
53 |
54 |
55 | def conv3d(input_, output_dim,
56 | k_d=4, k_h=4, k_w=4, d_d=2, d_h=2, d_w=2, stddev=0.02,
57 | name="conv3d"):
58 | with tf.variable_scope(name):
59 | w = tf.get_variable('w', [k_d, k_h, k_w, input_.get_shape()[-1], output_dim],
60 | initializer=tf.truncated_normal_initializer(stddev=stddev))
61 | conv = tf.nn.conv3d(input_, w, strides=[1, d_d, d_h, d_w, 1], padding='SAME')
62 |
63 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
64 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
65 |
66 | return conv
67 |
68 |
69 | def deconv3d(input_, output_shape,
70 | k_d=4, k_h=4, k_w=4, d_d=2, d_h=2, d_w=2, stddev=0.02,
71 | name="deconv3d", with_w=False):
72 | with tf.variable_scope(name):
73 |         # filter: [depth, height, width, output_channels, in_channels]
74 | w = tf.get_variable('w', [k_d, k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
75 | initializer=tf.random_normal_initializer(stddev=stddev))
76 |
77 | try:
78 | deconv = tf.nn.conv3d_transpose(input_, w, output_shape=output_shape,
79 | strides=[1, d_d, d_h, d_w, 1])
80 |
81 |         # Support for versions of TensorFlow before 0.7.0
82 | except AttributeError:
83 | deconv = tf.nn.deconv3d(input_, w, output_shape=output_shape,
84 | strides=[1, d_d, d_h, d_w, 1])
85 |
86 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
87 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
88 |
89 | if with_w:
90 | return deconv, w, biases
91 | else:
92 | return deconv
93 |
94 |
95 | def lrelu(x, leak=0.2, name="lrelu"):
96 | return tf.maximum(x, leak * x)
97 |
98 |
99 | def linear(input_, output_size, name='linear', stddev=0.02, bias_start=0.0, with_w=False):
100 | shape = input_.get_shape().as_list()
101 |
102 | with tf.variable_scope(name):
103 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
104 | tf.random_normal_initializer(stddev=stddev))
105 | bias = tf.get_variable("bias", [output_size],
106 | initializer=tf.constant_initializer(bias_start))
107 | if with_w:
108 | return tf.matmul(input_, matrix) + bias, matrix, bias
109 | else:
110 | return tf.matmul(input_, matrix) + bias
111 |
112 |
113 | def freq_split(s, r=8, mask_type='boxed'):
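    """Split a 3D field into low- and high-frequency components via FFT masking.

    Args:
        s: 3D numpy array (e.g. a 64x64x64 SDF field); dimensions must be even.
        r: half-width, in frequency bins, of the low-frequency mask.
        mask_type: 'boxed' for a cubic mask, 'circular' for a spherical mask.

    Returns:
        (s_lf, s_hf): low- and high-frequency fields that sum to the input.
    """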
114 | s = np.fft.fftn(s)
115 | dims = s.shape
116 | mids = [int(dims[i] / 2) for i in range(len(dims))]
117 |
118 | # frequency mask
119 | if mask_type == 'circular':
120 | U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]]
121 | D = np.sqrt(np.square(U) + np.square(V) + np.square(W))
122 | lf_mask = np.fft.ifftshift(np.less_equal(D, r * np.ones_like(D)).astype(np.complex_))
123 | hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask
124 |
125 | elif mask_type == 'boxed':
126 | U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]]
127 | mask = np.array([U < r, U >= -r, V < r, V >= -r, W < r, W >= -r])
128 | lf_mask = np.fft.ifftshift(np.all(mask, axis=0).astype(np.complex_))
129 | hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask
130 |
131 | else:
132 |         raise ValueError('mask_type undefined.')
133 |
134 | # apply mask
135 | s_lf = np.multiply(s, lf_mask)
136 | s_hf = np.multiply(s, hf_mask)
137 |
138 | s_lf = np.real(np.fft.ifftn(s_lf))
139 | s_hf = np.real(np.fft.ifftn(s_hf))
140 |
141 | return s_lf, s_hf
142 |
143 |
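# Usage sketch: lf_mask and hf_mask are exact complements in Fourier space,
# so the two bands recombine to the original signal up to floating-point
# error. A minimal check, assuming an illustrative 64^3 volume:
#
#   s = np.random.randn(64, 64, 64)
#   s_lf, s_hf = freq_split(s, r=8, mask_type='boxed')
#   assert np.allclose(s_lf + s_hf, s)
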
144 | def batch_lowpass(s_batch, r=8, mask_type='boxed'):
145 | assert len(s_batch.shape) > 3
146 | process_list = []
147 | for i in range(s_batch.shape[0]):
148 | s_lf, _ = freq_split(s_batch[i, :, :, :], r=r, mask_type=mask_type)
149 | process_list.append(s_lf)
150 |
151 | return np.array(process_list)
152 |
153 |
154 | def batch_mirr(s_batch, mirr='l'):
155 | assert len(s_batch.shape) > 3
156 | process_list = []
157 | for i in range(s_batch.shape[0]):
158 | s = s_batch[i]
159 | if mirr == 'r':
160 | sr = s[:, :, int(s.shape[2] / 2):]
161 | sl = np.flip(sr, axis=2)
162 | s = np.concatenate((sl, sr), axis=2)
163 | elif mirr == 'l':
164 | sl = s[:, :, :int(s.shape[2] / 2)]
165 | sr = np.flip(sl, axis=2)
166 | s = np.concatenate((sl, sr), axis=2)
167 | process_list.append(s)
168 |
169 | return np.array(process_list)
170 |
171 |
172 |
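# Usage sketch: batch_mirr enforces bilateral symmetry along the last spatial
# axis by reflecting one half of each volume onto the other (useful for
# classes such as cars that are left/right symmetric). A minimal check,
# assuming an illustrative batch shape:
#
#   batch = np.random.randn(4, 64, 64, 64)
#   sym = batch_mirr(batch, mirr='l')   # keep the left half, mirror it right
#   assert np.allclose(sym, np.flip(sym, axis=3))
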
173 | def pad_glob_batch(glob_batch_images, glob_batch_size):
174 | """pad tensor with zeros if smaller than a full batch"""
175 | pad_shape = [j for j in glob_batch_images.shape]
176 | pad_shape[0] = glob_batch_size - pad_shape[0]
177 | pad_zeros = np.zeros(pad_shape)
178 | glob_batch_inputs = np.concatenate((glob_batch_images, pad_zeros), axis=0)
179 |
180 | return glob_batch_inputs
181 |
182 |
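# Usage sketch: the last batch of an epoch may be short, so it is zero-padded
# up to the global batch size to keep the multi-GPU graph shapes fixed.
# A minimal example, assuming illustrative shapes:
#
#   tail = np.ones((3, 64, 64, 64, 1))
#   full = pad_glob_batch(tail, glob_batch_size=8)
#   assert full.shape[0] == 8 and np.all(full[3:] == 0)
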
183 | def average_gradients(tower_grads):
184 | """Calculate the average gradient for each shared variable across all towers.
185 |
186 | Note that this function provides a synchronization point across all towers.
187 |
188 | Args:
189 | tower_grads: List of lists of (gradient, variable) tuples. The outer list
190 | is over individual gradients. The inner list is over the gradient
191 | calculation for each tower.
192 | Returns:
193 | List of pairs of (gradient, variable) where the gradient has been averaged
194 | across all towers.
195 | """
196 | # Borrowed code from
197 | # https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py
198 |
199 | average_grads = []
200 | for grad_and_vars in zip(*tower_grads):
201 | # Note that each grad_and_vars looks like the following:
202 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
203 | grads = []
204 | for g, _ in grad_and_vars:
205 | # Add 0 dimension to the gradients to represent the tower.
206 | expanded_g = tf.expand_dims(g, 0)
207 |
208 | # Append on a 'tower' dimension which we will average over below.
209 | grads.append(expanded_g)
210 |
211 | # Average over the 'tower' dimension.
212 | grad = tf.concat(axis=0, values=grads)
213 | grad = tf.reduce_mean(grad, 0)
214 |
215 | # Keep in mind that the Variables are redundant because they are shared
216 |         # across towers. So we will just return the first tower's pointer to
217 | # the Variable.
218 | v = grad_and_vars[0][1]
219 | grad_and_var = (grad, v)
220 | average_grads.append(grad_and_var)
221 | return average_grads
222 |
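# Usage sketch: averaging the per-tower gradients of one shared variable.
# The constants below stand in for real tower gradients; this is a minimal
# illustration, not how the trainer calls it:
#
#   v = tf.Variable(0.0)
#   tower_grads = [[(tf.constant(1.0), v)], [(tf.constant(3.0), v)]]
#   avg = average_gradients(tower_grads)
#   with tf.Session() as sess:
#       print(sess.run(avg[0][0]))   # -> 2.0
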
--------------------------------------------------------------------------------
/stream_freqsplit.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from glob import glob
3 | import multiprocessing
4 | from os import path, makedirs
5 | from tqdm import tqdm
6 | import os
7 | import numpy as np
8 | import argparse
9 |
10 |
11 | def freq_split(s, r=8, mask_type='boxed'):
12 | '''
13 | Split frequencies of a given 3D signal
14 | :param s: input signal matrix of shape (dim_0, dim_1, dim_2)
15 |     :param r: cutoff radius, in modes, separating the low and high bands
16 | :param mask_type: 'boxed' or 'circular'
17 | :return: s_lf, s_hf: low and high frequencies of the signal
18 | '''
19 |
20 | s = np.fft.fftn(s)
21 | dims = s.shape
22 | mids = [int(dims[i] / 2) for i in range(len(dims))]
23 |
24 | # frequency mask
25 | if mask_type == 'circular':
26 | U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]]
27 | D = np.sqrt(np.square(U) + np.square(V) + np.square(W))
28 | lf_mask = np.fft.ifftshift(np.less_equal(D, r * np.ones_like(D)).astype(np.complex_))
29 | hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask
30 |
31 | elif mask_type == 'boxed':
32 | U, V, W = np.mgrid[-mids[0]:mids[0], -mids[1]:mids[1], -mids[2]:mids[2]]
33 | mask = np.array([U < r, U >= -r, V < r, V >= -r, W < r, W >= -r])
34 | lf_mask = np.fft.ifftshift(np.all(mask, axis=0).astype(np.complex_))
35 | hf_mask = np.ones_like(lf_mask, dtype=np.complex_) - lf_mask
36 |
37 | else:
38 |         raise ValueError('mask_type undefined.')
39 |
40 | # apply mask
41 | s_lf = np.multiply(s, lf_mask)
42 | s_hf = np.multiply(s, hf_mask)
43 |
44 | s_lf = np.real(np.fft.ifftn(s_lf))
45 | s_hf = np.real(np.fft.ifftn(s_hf))
46 |
47 | return s_lf, s_hf
48 |
49 |
50 | def process_mesh(read_path, write_path):
51 | # read mesh
52 | s = np.load(read_path)[0, :, :, :]
53 | s_lf, s_hf = freq_split(s, r=8)
54 | s_out = np.array([s_lf, s_hf])
55 | np.save(write_path, s_out)
56 |
57 |
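# Usage sketch: each input .npy is expected to hold an SDF volume of shape
# (1, D, H, W); the output stacks the two bands into shape (2, D, H, W).
# A minimal example with hypothetical paths:
#
#   np.save('/tmp/sdf_in.npy', np.random.randn(1, 64, 64, 64))
#   process_mesh('/tmp/sdf_in.npy', '/tmp/sdf_out.npy')
#   assert np.load('/tmp/sdf_out.npy').shape == (2, 64, 64, 64)
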
58 | def helper(args):
59 | return process_mesh(*args)
60 |
61 |
62 | def base_name(pathname):
63 |     """Return the base name of a path, i.e. the final folder or file name"""
64 |     return os.path.basename(os.path.normpath(pathname))
65 |
66 |
67 | def cond_mkdir(directory):
68 | """Conditionally make directory if it does not exist"""
69 | if not path.exists(directory):
70 | makedirs(directory)
71 | else:
72 |         print('Directory ' + directory + ' already exists; not creating it.')
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 | # argument parsing
78 |     parser = argparse.ArgumentParser(description="frequency-split SDF volumes using the Python multiprocessing module")
79 | parser.add_argument('datadir', type=str, help='Path to original dataset.')
80 | parser.add_argument('savedir', type=str, help='Path to saving processed dataset.')
81 |     parser.add_argument('ncore', type=int, nargs='?', default=0, help='Number of cores. Enter 0 to use all available CPU cores.')
82 |     parser.add_argument('nmesh', type=int, nargs='?', default=0, help='Process only the first nmesh files. Enter 0 to process all.')
83 | argument = parser.parse_args()
84 |
85 | data_dir = argument.datadir
86 | save_dir = argument.savedir
87 |
88 | if argument.ncore == 0:
89 | num_cores = multiprocessing.cpu_count()
90 | print("Using max cpu cores = %d" % (num_cores))
91 | else:
92 | num_cores = argument.ncore
93 |
94 | cond_mkdir(save_dir)
95 | read_loc = glob(os.path.join(data_dir, "*"))
96 | write_loc = []
97 |
98 | for d_in in read_loc:
99 | d_out = path.join(save_dir, base_name(d_in))
100 | write_loc.append(d_out)
101 |
102 | if argument.nmesh != 0:
103 | read_loc = read_loc[:argument.nmesh]
104 | write_loc = write_loc[:argument.nmesh]
105 |
106 | args = []
107 |
108 |     # build (read_path, write_path) argument tuples for the worker pool
109 |     for read_path, write_path in zip(read_loc, write_loc):
110 |         args.append((read_path, write_path))
111 |
112 |     print("Starting multithreaded processing with {0} cores.".format(num_cores))
113 |
114 |     pool = multiprocessing.Pool(num_cores)
115 |
116 |     with tqdm(total=len(read_loc)) as pbar:
117 |         for _ in pool.imap_unordered(helper, args):
118 |             pbar.update()
119 |
120 |     pool.close()
121 |     pool.join()
122 |
--------------------------------------------------------------------------------
/test_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # import global variables PROJ_DIR, JOB_NAME and DATASET
4 | . ./define_path.sh
5 |
6 | # variables
7 | JOB_DIR=${PROJ_DIR}/jobs/${JOB_NAME}
8 | DATASET_DIR=${PROJ_DIR}/data
9 | LOG_DIR=${JOB_DIR}/logs
10 | SAMPLE_DIR=${JOB_DIR}/samples
11 | CHECKPOINT_DIR=${JOB_DIR}/checkpoint
12 |
13 | # execute
14 | python main.py \
15 | --batch_size=64 \
16 | --dataset_dir=${DATASET_DIR} \
17 | --log_dir=${LOG_DIR} \
18 | --sample_dir=${SAMPLE_DIR} \
19 | --checkpoint_dir=${CHECKPOINT_DIR} \
20 | --dataset=${DATASET} \
21 | --sample_num=64
22 |
--------------------------------------------------------------------------------
/train_pix2pix.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # import global variables PROJ_DIR, JOB_NAME and DATASET
4 | . ./define_path.sh
5 |
6 | # variables
7 | JOB_DIR=${PROJ_DIR}/jobs/${JOB_NAME}
8 | DATASET_DIR=${PROJ_DIR}/data
9 | LOG_DIR=${JOB_DIR}/logs
10 | SAMPLE_DIR=${JOB_DIR}/samples
11 | CHECKPOINT_DIR=${JOB_DIR}/checkpoint
12 |
13 | # execute
14 | python main.py \
15 | --model=pix2pix \
16 | --is_train \
17 | --epoch=500 \
18 | --batch_size=32 \
19 | --dataset_dir=${DATASET_DIR} \
20 | --log_dir=${LOG_DIR} \
21 | --sample_dir=${SAMPLE_DIR} \
22 | --checkpoint_dir=${CHECKPOINT_DIR} \
23 | --num_gpus=4 \
24 | --dataset=${DATASET}_freqsplit \
25 | --g_learning_rate=0.0005 \
26 | --d_learning_rate=0.0002 \
27 | --gan_weight=1 \
28 | --l1_weight=100 \
29 | --image_depth=64 \
30 |     --beta1=0.5
31 |
--------------------------------------------------------------------------------
/train_sdfgan.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # import global variables PROJ_DIR, JOB_NAME and DATASET
4 | . ./define_path.sh
5 |
6 | # variables
7 | JOB_DIR=${PROJ_DIR}/jobs/${JOB_NAME}
8 | DATASET_DIR=${PROJ_DIR}/data
9 | LOG_DIR=${JOB_DIR}/logs
10 | SAMPLE_DIR=${JOB_DIR}/samples
11 | CHECKPOINT_DIR=${JOB_DIR}/checkpoint
12 |
13 | # execute
14 | python main.py \
15 | --model=sdfgan \
16 | --is_train \
17 | --epoch=500 \
18 | --batch_size=64 \
19 | --dataset_dir=${DATASET_DIR} \
20 | --log_dir=${LOG_DIR} \
21 | --sample_dir=${SAMPLE_DIR} \
22 | --checkpoint_dir=${CHECKPOINT_DIR} \
23 | --num_gpus=4 \
24 | --dataset=${DATASET} \
25 | --g_learning_rate=0.0005 \
26 | --d_learning_rate=0.0002 \
27 | --image_depth=64 \
28 | --beta1=0.5 \
29 | --is_new
30 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some codes from https://github.com/Newmu/dcgan_code
3 | """
4 | from __future__ import division
5 | import math
6 | import os
7 | import pprint
8 | import numpy as np
9 |
10 | import tensorflow as tf
11 | import tensorflow.contrib.slim as slim
12 |
13 | pp = pprint.PrettyPrinter()
14 |
15 | get_stddev = lambda x, k_h, k_w: 1 / math.sqrt(k_w * k_h * x.get_shape()[-1])
16 |
17 |
18 | def show_all_variables():
19 | model_vars = tf.trainable_variables()
20 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
21 |
22 |
23 | def create_sdfgan_samples(sess, sdfgan, config):
24 |     z_sample = np.random.uniform(-1, 1, size=(config.sample_num, sdfgan.z_dim))
25 | samples = sess.run(sdfgan.sampler, feed_dict={sdfgan.z: z_sample})
26 | fname = os.path.join(config.sample_dir, "sdfgan_final_samples.npy")
27 | np.save(fname, samples)
28 |
29 | return fname
30 |
31 |
32 | def create_pix2pix_samples(sess, pix2pix, sample_in):
33 | if len(sample_in.shape) == 4:
34 | sample_in = np.expand_dims(sample_in, axis=-1)
35 | sample_out = sess.run(pix2pix.sampler, feed_dict={pix2pix.sample_inputs: sample_in})
36 |
37 | return sample_out
38 |
--------------------------------------------------------------------------------