├── README.md ├── apply_warps.py ├── create_networks.py ├── create_unwarper.py ├── make_net.py ├── pyramid.py ├── sft_utils.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Toward a Universal Model for Shape from Texture 3 | 4 | This repository contains code for: 5 | 6 | **[Toward a Universal Model for Shape from Texture](http://vision.seas.harvard.edu/sft/)** 7 |
8 | [Dor Verbin](https://scholar.harvard.edu/dorverbin) and [Todd Zickler](http://www.eecs.harvard.edu/~zickler/) 9 |
10 | CVPR 2020 11 | 12 | 13 | Please contact us by email for questions about our paper or code. 14 | 15 | 16 | 17 | 18 | ## Requirements 19 | 20 | Our code is implemented in tensorflow. It has been tested using tensorflow 1.9 but it should work for other tensorflow 1.x versions. The following packages are required: 21 | 22 | - python 3.x 23 | - tensorflow 1.x 24 | - numpy >= 1.14.0 25 | - pillow >= 5.1.0 26 | - matplotlib >= 2.2.2 (only used for plotting during training) 27 | 28 | 29 | 30 | ## Running the model 31 | 32 | To run our model, execute the following: 33 | ``` 34 | python train.py -image_path <image_path> -output_folder <output_folder> 35 | ``` 36 | 37 | In order to plot results during training using `matplotlib`, specify `-do_plot True`. To save the models, use `-do_save_model True`. 38 | 39 | Using an NVIDIA Tesla V100 GPU, training takes about 110 minutes for a `640x640` image. 40 | 41 | 42 | 43 | 44 | ## Data 45 | 46 | Our synthetic dataset was generated in [Blender](http://www.blender.org) by using cloth simulation. The texture images were mapped onto a square mesh 47 | and dropped onto a surface. After the simulation is done running, the result is rendered. Blender also enables extracting the ground truth surface normals 48 | by saving them into an `.stl` file (go to `Export > Stl` and make sure "Selection Only" is checked). 49 | 50 | We provide two files below: 51 | - A zip file containing all images can be downloaded [here](http://vision.seas.harvard.edu/sft/data/images.zip) (15.2 MB). 52 | - A zip file containing all source Blender files and `.stl` files can be downloaded [here](http://vision.seas.harvard.edu/sft/data/models.zip) (22.8 MB). 53 | 54 | The file containing all images contains five directories corresponding to the four shapes in the paper plus one containing the original (flat) images. 55 | The file containing the Blender models has four directories corresponding to the four shapes. 
Each one contains a Blender 56 | file with an embedded python script which can be run to automatically render all images used in the paper. Each directory also contains an `.stl` file extracted from Blender, which stores the 57 | true shape. The script can also be used to generate the images from the paper, with shading and specular highlights. 58 | The `sphere` directory also contains the Blender files and embedded python scripts used to generate the images for Figures 7 and 8 from our paper. 59 | Note: In order to use the Blender files, the two zip files must be unzipped in the same directory (only the flat directory is used by the Blender files). 60 | 61 | 62 | 63 | ## Citation 64 | 65 | For citing our paper, please use: 66 | ``` 67 | @InProceedings{verbin2020sft, 68 | author = {Verbin, Dor and Zickler, Todd}, 69 | title = {Toward a Universal Model for Shape From Texture}, 70 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 71 | month = {June}, 72 | year = {2020} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /apply_warps.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def apply_warps(img, loc, warps, shape): 4 | """ 5 | Apply affine warps to images. 6 | 7 | Inputs 8 | ------ 9 | img Input images, tensor of shape [N, H, W, C] 10 | loc A tensor holding the center pixel location for each image. Must have shape [N, 2]. 11 | warps A translation-free affine matrix for each pixel, determining the transformation 12 | of each patch. Shape [H, W, 4]. The order of elements is [m11, m12, m21, m22]. 13 | shape A tuple (H_out, W_out) specifying the output shape (must be the same for entire batch). 14 | 15 | 16 | Outputs 17 | ------- 18 | output A tensor of shape [N, H_out, W_out, C]. 19 | 20 | Description 21 | ----------- 22 | For each i = 0, ..., N-1, loc[i, :] = (x0, y0) defines a patch location. 
For all i, we 23 | use the matrix [warps[y0, x0, 0], warps[y0, x0, 1] ; warps[y0, x0, 2], warps[y0, x0, 3]] to warp a 24 | regular grid centered at (x0, y0). The input image, img[i, :, :, :], sampled by this grid is 25 | output[i, :, :, :]. 26 | """ 27 | 28 | # Get sizes of input image 29 | N = tf.shape(img)[0] 30 | H = tf.shape(img)[1] 31 | W = tf.shape(img)[2] 32 | 33 | # Define half width and height for output patches 34 | hw = (tf.cast(shape[1], dtype=tf.float32) - 1) / 2.0 35 | hh = (tf.cast(shape[0], dtype=tf.float32) - 1) / 2.0 36 | 37 | # Create grid for output patches 38 | x_, y_ = tf.meshgrid(tf.range(-hw, hw+1), tf.range(-hh, hh+1)) 39 | 40 | # Get warp matrix for each location in loc 41 | warps_r = tf.reshape(warps, [1, H, W, 4]) # Reshaped to [1, H, W, 4] 42 | warps_rt = tf.tile(warps_r, [N, 1, 1, 1]) # Tiled to [N, H, W, 4] 43 | warp = tf.contrib.resampler.resampler(warps_rt, loc) # Shape [N, 4] 44 | 45 | # Get elements of warp matrix (each has shape [N, 1, 1, 1]) 46 | w11 = tf.reshape(warp[:, 0], [-1, 1, 1, 1]) 47 | w12 = tf.reshape(warp[:, 1], [-1, 1, 1, 1]) 48 | w21 = tf.reshape(warp[:, 2], [-1, 1, 1, 1]) 49 | w22 = tf.reshape(warp[:, 3], [-1, 1, 1, 1]) 50 | 51 | # Reshape x, y to [1, H_out, W_out, 1] 52 | sh = tf.stack([1, shape[0], shape[1], 1], axis=0) 53 | x = tf.reshape(x_, sh) 54 | y = tf.reshape(y_, sh) 55 | 56 | # Compute warped grid (shape [N, shape[0], shape[1], 1] each) 57 | x_sample = tf.clip_by_value(tf.reshape(loc[:, 0], [-1, 1, 1, 1]) + w11 * x + w12 * y, 58 | 0.0, tf.cast(W, dtype=tf.float32)-1.0) 59 | y_sample = tf.clip_by_value(tf.reshape(loc[:, 1], [-1, 1, 1, 1]) + w21 * x + w22 * y, 60 | 0.0, tf.cast(H, dtype=tf.float32)-1.0) 61 | 62 | # Sample using the warped grid 63 | sample_points = tf.concat([x_sample, y_sample], axis=-1) 64 | output = tf.contrib.resampler.resampler(img, sample_points) 65 | 66 | return output 67 | -------------------------------------------------------------------------------- /create_networks.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from make_net import make_net 4 | from sft_utils import get_z_shape_list 5 | 6 | 7 | def create_generator_discriminator(gen_arch, disc_arch, z_shape, dim_z_global, dim_z_local_list, dim_z_periodic_list, 8 | global_mlp_hidden_units, do_tie_phases, 9 | do_print=True, disc_input=None): 10 | """ 11 | Create generator and discriminator networks. 12 | 13 | Inputs 14 | ------ 15 | gen_arch Architecture string of generator network. 16 | disc_arch Architecture string of discriminator network. 17 | z_shape Tuple (of length 2) describing input noise shape. 18 | dim_z_global Dimension of global noise 19 | dim_z_local_list List of local noise dimensions 20 | dim_z_periodic_list List of periodic noise dimensions 21 | global_mlp_hidden_units Number of hidden units for MLP predicting wave vectors from global dimensions. 22 | do_tie_phases If True, tie the phases of all periodic dimensions to a single random 2D shift. 23 | do_print If True, print network architectures. 24 | disc_input If specified, use as the real input to the discriminator instead of creating a 25 | placeholder. 26 | 27 | Outputs 28 | ------- 29 | d_ph Dictionary holding all placeholders. Keys are names, values are placeholders. 30 | d_tensors Dictionary holding all tensors. Keys are names, values are tensors. 31 | 32 | """ 33 | 34 | # Compute total local dimensions and list of spatial sizes 35 | dim_z_local = sum(dim_z_local_list) 36 | shape_list = get_z_shape_list(z_shape, len(dim_z_local_list), gen_arch) 37 | 38 | # Create phase placeholder 39 | with tf.variable_scope('phase'): 40 | is_training = tf.placeholder(tf.bool, shape=()) 41 | 42 | # Create discriminator input placeholder (unless disc_input is specified) 43 | with tf.variable_scope('discriminator_input/'): 44 | if disc_input is None: 45 | disc_input = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='disc_input') 46 | 47 | # Create input placeholders to generator (i.e. 
global and local dimensions) 48 | with tf.variable_scope('generator_input/'): 49 | # Create global z input 50 | if dim_z_global > 0: 51 | global_gen_input = tf.placeholder(tf.float32, shape=[None, 1, 1, dim_z_global], name='input_z_global') 52 | # Tile global z spatially and add to list 53 | gen_input = tf.tile(global_gen_input, [1, shape_list[0][0], shape_list[0][1], 1]) 54 | else: 55 | global_gen_input = gen_input = None # No global z; keep the name defined since it is returned in d_ph 56 | 57 | # Create local z input 58 | if dim_z_local > 0: 59 | local_gen_input_list = [] 60 | 61 | for i, (dzl, shape) in enumerate(zip(dim_z_local_list, shape_list)): 62 | if dzl > 0: 63 | print("Creating local tensor of shape {}x{}x{}".format(shape[0], shape[1], dzl)) 64 | local_gen_input_list.append(tf.placeholder(tf.float32, shape=[None]+list(shape)+[dzl], 65 | name='input_z_local_scale{}'.format(i))) 66 | else: 67 | local_gen_input_list.append(None) 68 | else: 69 | local_gen_input_list = [None] * len(dim_z_local_list) 70 | 71 | # Get total number of periodic dimensions 72 | dim_z_periodic = sum(dim_z_periodic_list) 73 | 74 | # If there is at least one periodic dimension, create a placeholder for phases (otherwise create list of Nones). 75 | if dim_z_periodic > 0: 76 | # Create a placeholder for each periodic map 77 | with tf.variable_scope('generator_input/'): 78 | random_phases = tf.placeholder(tf.float32, shape=[None, 1, 1, dim_z_periodic if not do_tie_phases else 2], 79 | name='random_phases') 80 | 81 | # Set variables for wave vectors if learned directly. Otherwise use MLP to predict them from global dimensions. 
82 | if dim_z_global == 0: 83 | with tf.variable_scope('wave_vectors'): 84 | w = tf.get_variable('w', shape=[1, 1, 1, dim_z_periodic], 85 | initializer=tf.initializers.random_normal(mean=0.0, stddev=0.1)) 86 | theta = tf.get_variable('theta', shape=[1, 1, 1, dim_z_periodic], 87 | initializer=tf.initializers.random_uniform(minval=0.0, maxval=2*np.pi)) 88 | else: 89 | with tf.variable_scope('wave_vectors'): 90 | hidden = tf.layers.dense(global_gen_input, units=global_mlp_hidden_units, activation=tf.nn.relu) 91 | w = tf.reshape(tf.layers.dense(hidden, units=dim_z_periodic, activation=None, name='w'), 92 | [-1, 1, 1, dim_z_periodic]) 93 | theta = tf.reshape(tf.layers.dense(hidden, units=dim_z_periodic, activation=None, name='theta'), 94 | [-1, 1, 1, dim_z_periodic]) 95 | 96 | # Compute cumulative sum of periodic dimension list for selecting indices of theta and w. 97 | p_end_ind = list(np.cumsum(dim_z_periodic_list)) 98 | p_start_ind = [0] + p_end_ind[:-1] 99 | 100 | # Create list of periodic maps using learned w and theta. 101 | periodic_gen_input_list = [] 102 | with tf.variable_scope('generator_input/'): 103 | with tf.variable_scope('z_periodic'): 104 | for i, (ps, pe, shape) in enumerate(zip(p_start_ind, p_end_ind, shape_list)): 105 | # Only create a map if the number of periodic dimensions to add is > 0 106 | if pe > ps: 107 | print("Creating periodic tensor of shape {}x{}x{}".format(shape[0], shape[1], pe - ps)) 108 | kx = 0.5 * tf.nn.sigmoid(w[:, :, :, ps:pe]) * tf.cos(theta[:, :, :, ps:pe]) 109 | ky = 0.5 * tf.nn.sigmoid(w[:, :, :, ps:pe]) * tf.sin(theta[:, :, :, ps:pe]) 110 | 111 | x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0])) 112 | sinusoid_args = kx * x[np.newaxis, :, :, np.newaxis] + ky * y[np.newaxis, :, :, np.newaxis] 113 | 114 | if do_tie_phases: 115 | # TODO: For correctly matching the phases over different scales we must multiply 116 | # random_phases by 2 ** i. The results in the paper were obtained without this factor 117 | # by mistake. 
Both options are implemented below, but the original one used for the paper 118 | # is active for reproducibility purposes, although both versions seem to work well. 119 | p_concat_tensor = tf.sin(2 * np.pi * sinusoid_args + 2 * np.pi * 120 | (kx*random_phases[:, :, :, :1] + ky*random_phases[:, :, :, 1:]), 121 | name='periodic_scale{}'.format(i)) 122 | # p_concat_tensor = tf.sin(2 * np.pi * sinusoid_args + 2 * np.pi * 123 | # (kx*random_phases[:, :, :, :1] + 124 | # ky*random_phases[:, :, :, 1:]) * 2 ** i, 125 | # name='periodic_scale{}'.format(i)) 126 | 127 | else: 128 | p_concat_tensor = tf.sin(2 * np.pi * sinusoid_args + random_phases[:, :, :, ps:pe], 129 | name='periodic_scale{}'.format(i)) 130 | 131 | periodic_gen_input_list.append(p_concat_tensor) 132 | 133 | else: 134 | periodic_gen_input_list.append(None) # no periodic tensor 135 | 136 | else: 137 | periodic_gen_input_list = [None] * len(dim_z_periodic_list) 138 | 139 | assert len(periodic_gen_input_list) == len(local_gen_input_list), "Must be same length (even if some are None)" 140 | 141 | # Create list of inputs to each deconvolution layer 142 | concat_tensor_list = [] 143 | for i in range(len(periodic_gen_input_list)): 144 | # Get the ith local and periodic tensor 145 | l_tensor = local_gen_input_list[i] 146 | p_tensor = periodic_gen_input_list[i] 147 | 148 | # Get list of tensors input to the ith deconvolution, ignoring any Nones. 149 | curr_concat_list = [] 150 | for t in [p_tensor, l_tensor]: 151 | if t is not None: 152 | curr_concat_list.append(t) 153 | 154 | # Concatenate the tensors and append to concat_tensor_list. 
155 | if len(curr_concat_list) > 0: 156 | concat_tensor_list.append(tf.concat(curr_concat_list, axis=3, name='concat_list_{}'.format(i))) 157 | else: 158 | concat_tensor_list.append(None) 159 | 160 | # Create generator network 161 | with tf.variable_scope('generator'): 162 | gen_output = make_net(gen_input, gen_arch, is_generator=True, concat_list=concat_tensor_list.copy(), 163 | norm_layer='BN', is_training=is_training, do_print=do_print) 164 | 165 | # Create scaled version of gen_output in [0, 1] 166 | gen_sample = tf.identity(tf.nn.tanh(gen_output) * 0.5 + 0.5, name='gen_sample') 167 | 168 | # Create discriminator network. 169 | with tf.variable_scope('discriminator'): 170 | # Get discriminator output when applied to disc_input (either placeholder or input to function). 171 | disc_real = tf.identity(make_net(disc_input, disc_arch, is_generator=False, norm_layer='None', 172 | is_training=is_training, do_print=do_print), name='disc_real') 173 | # Get discriminator output when applied to the output of the generator 174 | disc_fake = tf.identity(make_net(gen_sample, disc_arch, is_generator=False, norm_layer='None', 175 | reuse=True, is_training=is_training, do_print=do_print), name='disc_fake') 176 | 177 | # Return placeholders and tensors 178 | d_ph = {'is_training': is_training, 179 | 'disc_input': disc_input, 180 | 'global_gen_input': global_gen_input, 181 | 'local_gen_input_list': local_gen_input_list} 182 | 183 | if dim_z_periodic > 0: 184 | d_ph['random_phases'] = random_phases 185 | 186 | d_tensors = {'disc_fake': disc_fake, 187 | 'disc_real': disc_real, 188 | 'gen_sample': gen_sample} 189 | 190 | return d_ph, d_tensors 191 | 192 | 193 | def create_loss(disc_real, disc_fake, wd_mult): 194 | """ 195 | Create losses for discriminator and generator. 196 | 197 | Inputs 198 | ------ 199 | disc_real Logits output by discriminator applied to unwarper output 200 | disc_fake Logits output by discriminator applied to generator output 201 | wd_mult Weight decay amount. 
202 | 203 | Outputs 204 | ------- 205 | gen_loss Generator loss 206 | disc_loss Discriminator loss 207 | """ 208 | 209 | # Get all generator and discriminator kernels 210 | generator_kernels = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 211 | scope='generator') if 'kernel' in x.name] 212 | discriminator_kernels = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 213 | scope='discriminator') if 'kernel' in x.name] 214 | 215 | # Create weight decay terms 216 | with tf.variable_scope('weight_decay'): 217 | g_wd = wd_mult * tf.add_n([tf.nn.l2_loss(kernel) for kernel in generator_kernels]) 218 | d_wd = wd_mult * tf.add_n([tf.nn.l2_loss(kernel) for kernel in discriminator_kernels]) 219 | 220 | # Get losses 221 | with tf.variable_scope('loss'): 222 | # Generator loss is -log(sigmoid(disc_fake)) 223 | gen_loss = g_wd + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(disc_fake), 224 | logits=disc_fake)) 225 | 226 | # Discriminator loss is -log(sigmoid(disc_real)) - log(1 - sigmoid(disc_fake)) 227 | disc_loss = d_wd + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(disc_real), 228 | logits=disc_real)) + \ 229 | tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(disc_fake), 230 | logits=disc_fake)) 231 | 232 | return gen_loss, disc_loss 233 | 234 | 235 | def create_dg_optimizers(gen_loss, disc_loss, learning_rate, beta1): 236 | """ 237 | Create optimizers for discriminator and generator. 
238 | 239 | Inputs 240 | ------ 241 | gen_loss Generator loss 242 | disc_loss Discriminator loss 243 | learning_rate Learning rate for Adam optimizer 244 | beta1 beta1 parameter for Adam optimizer 245 | 246 | Outputs 247 | ------- 248 | train_gen Generator optimization operation 249 | train_disc Discriminator optimization operation 250 | """ 251 | with tf.variable_scope('optimizers'): 252 | optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1) 253 | optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1) 254 | 255 | # Make sure batchnorm layer parameters are updated. 256 | bn_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 257 | gen_bn_update_ops = [x for x in bn_update_ops if 'generator' in x.name] 258 | disc_bn_update_ops = [x for x in bn_update_ops if 'discriminator' in x.name] 259 | 260 | # Create optimization operations 261 | with tf.control_dependencies(gen_bn_update_ops): 262 | gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator') + \ 263 | tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='wave_vectors') 264 | train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars) 265 | 266 | with tf.control_dependencies(disc_bn_update_ops): 267 | disc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator') + \ 268 | tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='wave_vectors') 269 | train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars) 270 | 271 | return train_gen, train_disc 272 | 273 | 274 | def create_no_nan_assertions(): 275 | """ 276 | Create operations for making sure there are no NaN or Inf values in tensors. 277 | 278 | Outputs 279 | ------- 280 | List of operations asserting there exist no NaN or Inf values. 
281 | """ 282 | 283 | with tf.variable_scope('no_nan_assertions'): 284 | no_nan_assertions = [tf.verify_tensor_all_finite(t, 'Tensor {} contains bad values!'.format(t.name)) 285 | for t in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)] 286 | 287 | return no_nan_assertions 288 | -------------------------------------------------------------------------------- /create_unwarper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from pyramid import pyr2im 3 | import numpy as np 4 | 5 | 6 | def create_unwarper(H, W): 7 | """ 8 | Create unwarper. 9 | 10 | Inputs 11 | ------ 12 | H Height of image 13 | W Width of image 14 | 15 | Outputs 16 | ------- 17 | warps Warp matrices 18 | n Normal map 19 | t Tangent vector map 20 | n_smoothness_loss Smoothness loss over normals 21 | t_smoothness_loss Smoothness loss over tangent vectors 22 | integrability_loss Integrability loss 23 | 24 | Description 25 | ----------- 26 | Represent the x and y components of the normals as an image pyramid (see Section 5.2 in paper). Using the normal 27 | and tangent vector at each pixel, we define the warp matrix at each pixel and store it under `warps`. The losses 28 | are defined in Section 5.1 in the paper. 
29 | """ 30 | 31 | with tf.variable_scope('warp'): 32 | 33 | # Create pyramids for nx and ny 34 | n_levels = int(np.ceil(min([np.log2(H), np.log2(W)]))) + 1 # Number of levels: 1 pixel in highest level 35 | n_pyramids = [[], []] 36 | for i in range(n_levels): 37 | div = 2 ** i 38 | for j in range(2): 39 | init = np.float32(np.random.rand(1, int(np.ceil(H / div)), int(np.ceil(W / div)), 1) - 0.5) * 0.0001 40 | if j == 0: # Only print once 41 | print("Initializing normal pyramid level {} of shape {}".format(i, init.shape)) 42 | n_pyramids[j].append(tf.get_variable(name='n{}{}'.format(j, i), initializer=init)) 43 | 44 | # Get nx and ny from pyramid representation 45 | nxu = pyr2im(n_pyramids[0])[0, :, :, 0] 46 | nyu = pyr2im(n_pyramids[1])[0, :, :, 0] 47 | nzu = np.ones((H, W), dtype=np.float32) 48 | 49 | # Normalize ||n(x, y)|| = 1 50 | n_norm = tf.sqrt(nxu ** 2 + nyu ** 2 + nzu ** 2) 51 | nx = nxu / n_norm 52 | ny = nyu / n_norm 53 | nz = nzu / n_norm 54 | 55 | # Stack to shape [H, W, 3] 56 | n = tf.stack([nx, ny, nz], axis=-1) 57 | 58 | # Create coefficients for tangent vectors 59 | coeff1_init = np.float32(np.random.rand(1, H, W, 1) - 0.5) * 0.2 60 | coeff2_init = 1.0 + np.float32(np.random.rand(1, H, W, 1) - 0.5) * 0.2 61 | coeff1 = tf.get_variable(name='c1', dtype=tf.float32, initializer=coeff1_init)[0, :, :, 0] 62 | coeff2 = tf.get_variable(name='c2', dtype=tf.float32, initializer=coeff2_init)[0, :, :, 0] 63 | 64 | # Define tangent vector to be orthogonal to normal vector 65 | txu = nz * coeff2 66 | tyu = nz * coeff1 67 | tzu = -ny * coeff1 - nx * coeff2 68 | 69 | # Normalize ||t(x, y)|| = 1 70 | t_norm = tf.sqrt(txu ** 2 + tyu ** 2 + tzu ** 2 + 1e-4) 71 | tx = txu / t_norm 72 | ty = tyu / t_norm 73 | tz = tzu / t_norm 74 | 75 | # Stack to shape [H, W, 3] 76 | t = tf.stack([tx, ty, tz], axis=-1) 77 | 78 | # Define warp matrices 79 | w11 = tx 80 | w12 = ny * tz - nz * ty 81 | w21 = ty 82 | w22 = nz * tx - nx * tz 83 | warps = tf.stack([w11, w12, w21, w22], 
axis=-1) 84 | 85 | # Compute smoothness losses for n, t 86 | n_smoothness_loss = tf.reduce_mean(tf.square(n[1:, :, :] - n[:-1, :, :])) + \ 87 | tf.reduce_mean(tf.square(n[:, 1:, :] - n[:, :-1, :])) 88 | t_smoothness_loss = tf.reduce_mean(tf.square(t[1:, :, :] - t[:-1, :, :])) + \ 89 | tf.reduce_mean(tf.square(t[:, 1:, :] - t[:, :-1, :])) 90 | 91 | # Compute integrability loss 92 | p = nx / nz 93 | q = ny / nz 94 | 95 | # In the Horn paper, i corresponds to x and j to y so we're flipped 96 | # relative to it. Additionally they assume x goes right and y goes up. 97 | # 98 | # We can instead integrate clockwise, and then the signs are: 99 | # 100 | # p: + p: + 101 | # q: - (i, j) q: + (i, j+1) 102 | # 103 | # 104 | # 105 | # p: - p: - 106 | # q: - (i+1, j) q: + (i+1, j+1) 107 | 108 | pi0j0 = p[:-1, :-1] # p_{i,j} 109 | pi1j0 = p[1:, :-1] # p_{i+1,j} 110 | pi0j1 = p[:-1, 1: ] # p_{i,j+1} 111 | pi1j1 = p[1:, 1: ] # p_{i+1,j+1} 112 | 113 | qi0j0 = q[:-1, :-1] # q_{i,j} 114 | qi1j0 = q[1:, :-1] # q_{i+1,j} 115 | qi0j1 = q[:-1, 1: ] # q_{i,j+1} 116 | qi1j1 = q[1:, 1: ] # q_{i+1,j+1} 117 | 118 | integrability_loss = tf.reduce_mean(tf.square( pi0j0 + pi0j1 - pi1j0 - pi1j1 119 | - qi0j0 + qi0j1 - qi1j0 + qi1j1)) 120 | 121 | return warps, n, t, n_smoothness_loss, t_smoothness_loss, integrability_loss 122 | 123 | 124 | def create_w_optimizers(shape_loss, n_learning_rate, t_learning_rate): 125 | """ 126 | Create unwarper optimizers. 127 | 128 | Inputs 129 | ------ 130 | shape_loss Overall unwarper loss (see Equation 2 in paper). 131 | n_learning_rate Learning rate for normal vectors. 132 | t_learning_rate Learning rate for tangent vectors. 133 | 134 | Outputs 135 | ------- 136 | train_shape Operations for optimizing normal and tangent vector maps. 
137 | """ 138 | 139 | # Create separate optimizers for n and t 140 | n_optimizer_shape_vars = tf.train.AdamOptimizer(learning_rate=n_learning_rate) 141 | t_optimizer_shape_vars = tf.train.AdamOptimizer(learning_rate=t_learning_rate) 142 | 143 | # Collect shape parameters for n (normal pyramids) and t (tangent coefficients) 144 | n_params = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='warp') if 'n' in x.name] 145 | t_params = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='warp') if 'c' in x.name] 146 | 147 | # Shape training operations 148 | train_shape = list() 149 | train_shape.append(n_optimizer_shape_vars.minimize(shape_loss, var_list=n_params)) 150 | train_shape.append(t_optimizer_shape_vars.minimize(shape_loss, var_list=t_params)) 151 | 152 | return train_shape 153 | -------------------------------------------------------------------------------- /make_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def op2layer(op): 5 | """ 6 | Create tensorflow operation from string. 7 | 8 | Inputs 9 | ------ 10 | op String describing operation. 11 | 12 | Outputs 13 | ------- 14 | Operation corresponding to op. 15 | """ 16 | if op == 'reshape': 17 | return tf.reshape 18 | elif op == 'dense': 19 | return tf.layers.dense 20 | elif op == 'conv': 21 | return tf.layers.conv2d 22 | elif op == 'deconv': 23 | return deconv_and_crop 24 | elif op == 'batchnorm': 25 | return tf.layers.batch_normalization 26 | elif op == 'instancenorm': 27 | return tf.contrib.layers.instance_norm 28 | elif op == 'relu': 29 | return tf.nn.relu 30 | elif op == 'lrelu': 31 | return tf.nn.leaky_relu 32 | elif op == 'sigmoid': 33 | return tf.nn.sigmoid 34 | else: 35 | raise ValueError('Layer op {} not implemented yet'.format(op)) 36 | 37 | 38 | def arch_string2list(arch_str, is_generator, norm_layer='BN'): 39 | """ 40 | Create architecture list from string. 
41 | 42 | Inputs 43 | ------ 44 | arch_str String describing architecture 45 | is_generator True for generator, False for discriminator. 46 | norm_layer Normalization layer to use. BN for batchnorm, IN for instance norm, None for nothing. 47 | 48 | Outputs 49 | arch_list List of layers. Each entry is a dictionary with all parameters needed to create layer. 50 | 51 | Description 52 | ----------- 53 | arch_str is a string describing the architecture. The different layers are separated by an underscore. The different 54 | operations supported are: 55 | C: convolution, must have format C<k>-<nout> where k is the kernel size and nout is the number of output channels. 56 | D: deconvolution, must have format D<k>-<nout> where k is the kernel size and nout is the number of output channels. 57 | R: reshape, must have format R-<n> for vectorization or R-<h>-<w>-<c> for unvectorizing. 58 | F: fully connected, must have format F-<nout> where nout is the number of output features. 59 | 60 | is_generator is used for determining the activation of the final layer as well as whether to use ReLU or Leaky ReLU, 61 | as described in [3] Bergmann, Jetchev, and Vollgraf, "Learning texture manifolds with the periodic spatial GAN". 
62 | """ 63 | arch_list = [] 64 | last_layer = 'bottom' # Input to network 65 | layers = arch_str.split('_') 66 | for i, layer in enumerate(layers): 67 | fields = layer.split('-') 68 | ty = fields[0][0] # Layer type 69 | if ty == 'F': 70 | assert len(fields) == 2, "A fully connected layer must have format 'F-<nout>'" 71 | layer_name = 'fc{}'.format(i + 1) 72 | params = {'units': int(fields[1])} 73 | op = 'dense' 74 | layer = {'name': layer_name, 75 | 'op': 'dense'} 76 | elif ty == 'R': 77 | assert len(fields) in [2, 4], \ 78 | "A reshape layer must have 1 number for vectorizing ('R-<n>') or 3 for creating image ('R-<h>-<w>-<c>')" 79 | layer_name = 'reshape{}'.format(i + 1) 80 | params = {'shape': tuple([-1] + [int(x) for x in fields[1:]])} 81 | op = 'reshape' 82 | layer = {'name': layer_name, 83 | 'op': 'reshape'} 84 | elif ty in ['C', 'D']: 85 | op = 'deconv' if ty == 'D' else 'conv' 86 | assert len(fields) == 2,\ 87 | "A {} layer must have format '{}<k>-<nout>', where k is the kernel height and width".format(op, ty) 88 | layer_name = '{}{}'.format(op, i + 1) 89 | params = {'kernel_size': int(fields[0][1:]), 90 | 'strides': (2, 2), 91 | 'filters': int(fields[1]), # Number of filters 92 | 'padding': 'same' if op == 'deconv' else 'valid'} 93 | 94 | layer = {'name': layer_name, 95 | 'op': op} 96 | else: 97 | raise ValueError('Layer type {} not implemented yet'.format(ty)) 98 | 99 | # Add name, op, params and bottom to layer dictionary 100 | layer['name'] = layer_name 101 | layer['op'] = op 102 | layer['params'] = params 103 | layer['bottom'] = last_layer 104 | 105 | # Add layer to list 106 | arch_list.append(layer) 107 | 108 | # Now add activation layer, if needed. 109 | last_layer = layer_name 110 | 111 | # If generator and any layer except for reshape layers and last layer (last before any reshape layers), 112 | # add ReLU and then norm. If discriminator and any layer except for reshape, first and last, add leaky ReLU 113 | # and then norm. If first layer only add leaky ReLU. 
If last layer add nothing. 114 | if is_generator and ty != 'R' and not all([l[0] == 'R' for l in layers[i+1:]]): 115 | layer = create_relu_layer(last_layer, layer_name + '_relu') 116 | arch_list.append(layer) 117 | last_layer = layer_name + '_relu' 118 | 119 | if norm_layer == 'BN': 120 | layer = create_batchnorm_layer(last_layer, layer_name + '_bn') 121 | arch_list.append(layer) 122 | last_layer = layer_name + '_bn' 123 | elif norm_layer == 'IN': 124 | layer = create_instancenorm_layer(last_layer, layer_name + '_in') 125 | arch_list.append(layer) 126 | last_layer = layer_name + '_in' 127 | 128 | elif not is_generator and ty != 'R' and i != len(layers) - 1: 129 | layer = create_lrelu_layer(last_layer, layer_name + '_lrelu') 130 | arch_list.append(layer) 131 | last_layer = layer_name + '_lrelu' 132 | 133 | if i != 0: 134 | if norm_layer == 'BN': 135 | layer = create_batchnorm_layer(last_layer, layer_name + '_bn') 136 | arch_list.append(layer) 137 | last_layer = layer_name + '_bn' 138 | elif norm_layer == 'IN': 139 | layer = create_instancenorm_layer(last_layer, layer_name + '_in') 140 | arch_list.append(layer) 141 | last_layer = layer_name + '_in' 142 | 143 | else: 144 | last_layer = layer_name 145 | 146 | return arch_list 147 | 148 | 149 | def create_batchnorm_layer(bottom, name): 150 | return create_layer(bottom, name, 'batchnorm') 151 | 152 | 153 | def create_instancenorm_layer(bottom, name): 154 | return create_layer(bottom, name, 'instancenorm') 155 | 156 | 157 | def create_relu_layer(bottom, name): 158 | return create_layer(bottom, name, 'relu') 159 | 160 | 161 | def create_lrelu_layer(bottom, name): 162 | return create_layer(bottom, name, 'lrelu') 163 | 164 | 165 | def create_layer(bottom, name, op): 166 | layer = {'name': name, 167 | 'op': op, 168 | 'bottom': bottom} 169 | return layer 170 | 171 | 172 | 173 | 174 | 175 | def gen2disc(gen_str): 176 | """ 177 | Implement DCGAN rules to convert generator string to discriminator string. 
178 | """ 179 | disc_str = '' 180 | layers_inv = gen_str.split('_')[::-1] 181 | for i in range(len(layers_inv)): 182 | ty = layers_inv[i][0] 183 | if i == len(layers_inv) - 1: 184 | nout = 1 185 | else: 186 | nout_loc = layers_inv[i+1].rfind('-') + 1 187 | nout = layers_inv[i+1][nout_loc:] 188 | 189 | if ty == 'D': 190 | kernel_size = layers_inv[i][1:layers_inv[i].find('-')] 191 | disc_str += 'C{}-{}_'.format(kernel_size, nout) 192 | elif ty in ['R', 'F']: 193 | disc_str += '{}-{}_'.format(ty, nout) 194 | else: 195 | raise ValueError('Layer type {} not implemented yet'.format(ty)) 196 | 197 | return disc_str[:-1] # Remove last underscore 198 | 199 | 200 | def deconv_and_crop(bottom, kernel_size, strides, filters, padding, name): 201 | """ 202 | Operation implementing deconvolution + crop, as described in Section 4.1 of our paper. The `padding` argument is ignored: 'valid' padding is always used. 203 | """ 204 | res = tf.layers.conv2d_transpose(bottom, filters=filters, kernel_size=kernel_size, strides=strides, padding='valid', name=name) 205 | d = kernel_size - 2 206 | return res[:, d:-d, d:-d, :] 207 | 208 | 209 | def make_net(bottom, arch_str, is_generator, concat_list=[], norm_layer='BN', reuse=False, is_training=None, 210 | do_print=True): 211 | """ 212 | Create generator/discriminator from architecture string. 213 | 214 | Inputs 215 | ------ 216 | bottom Input tensor to network. 217 | arch_str Architecture string. See arch_string2list for details on formatting. 218 | is_generator True if network being created is generator, False if discriminator. 219 | concat_list List of tensors to concatenate to each deconvolution input (in generator). 220 | norm_layer Normalization layer to use throughout network (BN for batchnorm, IN for instance norm, None for none) 221 | reuse If set to True, reuse weights. 222 | is_training If set to True, run batchnorm layers in training mode. Otherwise run in evaluation mode. 223 | do_print If set to True, print network architecture. 

    Outputs
    -------
    layer_tf        Output tensor of network.
    """

    assert is_training is not None, "is_training must be set"

    # Create arch list from arch string
    arch_list = arch_string2list(arch_str, is_generator=is_generator, norm_layer=norm_layer)

    # Make sure that the generator is being created if any input is given (other than bottom)
    assert len(concat_list) == 0 or is_generator, "Discriminator should not have concat_list"

    # Print network architecture
    if do_print:
        print('Creating {} network:'.format('generator' if is_generator else 'discriminator'))
        print_arch_list(arch_list)

    # Create dictionary of layer names
    name2layer = {'bottom': bottom}

    # Create layers one by one
    for layer in arch_list:
        if reuse:
            tf.get_variable_scope().reuse_variables()

        # Set batchnorm mode
        if layer['op'] == 'batchnorm':
            layer['params'] = {'training': is_training}

        # If layer has parameters, apply them. For deconvolution perform concatenation.
        if 'params' in layer.keys():
            if layer['op'] == 'deconv' and len(concat_list) > 0:
                tensor_to_concat = concat_list.pop(0)  # Pop the first tensor on list
                # If it's None, don't concat anything. Otherwise, concat it to layer's input.
                if tensor_to_concat is None:
                    bottom_ = name2layer[layer['bottom']]
                else:
                    deconv_bottom = name2layer[layer['bottom']]
                    # If there is no bottom, just take tensor_to_concat (e.g. if dim_z_global = dim_z_local = 0)
                    if deconv_bottom is not None:
                        bottom_ = tf.concat([deconv_bottom, tensor_to_concat], axis=3)
                    else:
                        bottom_ = tensor_to_concat
            else:
                bottom_ = name2layer[layer['bottom']]

            layer_tf = op2layer(layer['op'])(bottom_, **layer['params'], name=layer['name'])
            if do_print:
                print("Creating {} layer {} with parameters {}".format(layer['op'], layer['name'], layer['params']))
        elif layer['op'] == 'instancenorm':  # Instance norm layer does not take a name parameter for some reason...
            layer_tf = op2layer(layer['op'])(name2layer[layer['bottom']], reuse=reuse, scope=layer['name'])
            if do_print:
                print("Creating {} layer {}".format(layer['op'], layer['name']))
        else:
            layer_tf = op2layer(layer['op'])(name2layer[layer['bottom']], name=layer['name'])
            if do_print:
                print("Creating {} layer {}".format(layer['op'], layer['name']))

        name2layer[layer['name']] = layer_tf

    assert len(concat_list) == 0, "Not all concat tensors were used (length at the end is {})".format(len(concat_list))

    # For generator, find kernel size (which we assert is constant for all deconv layers)
    # and crop k - 2 pixels from each of the two dimensions
    if is_generator:
        for layer in arch_list:
            if layer['op'] == 'deconv':
                kernel_size = layer['params']['kernel_size']
                break
        return layer_tf[:, :-(kernel_size-2), :-(kernel_size-2), :]

    return layer_tf


def print_arch_list(arch_list):
    """
    Print architecture from an architecture list (obtained from arch_string2list()).
    """
    # Print output layer
    print("{:^30}".format('output'))
    print(" " + " " * 13 + "|" + " " * 14 + " ")

    # For each layer (from top to bottom) print a square with its name and its op
    for i, layer in enumerate(arch_list[::-1]):
        print(" " + "-" * 28 + " ")
        print("| {:16} {:10}|".format(layer['name'], layer['op']))
        print(" " + "-" * 28 + " ")
        print(" " + " " * 13 + "|" + " " * 14 + " ")

    # Finally print the input layer
    print("{:^30}".format('input'))


if __name__ == '__main__':
    """
    Example usage.
    """
    gen_arch = "D5-512_D5-256_D5-128_D5-64_D5-3"
    disc_arch = gen2disc(gen_arch)

    x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='input')

    dim_z = 191
    with tf.variable_scope('generator'):
        generated_img = make_net(x, gen_arch, is_generator=True, is_training=True)
    with tf.variable_scope('discriminator'):
        # Build the discriminator from disc_arch (previously gen_arch was passed with is_generator=True)
        p_real = make_net(generated_img, disc_arch, is_generator=False, is_training=True)

--------------------------------------------------------------------------------
/pyramid.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf

"""
Implement pyramid operations as explained in Section 5.2 of our paper, and introduced in:
Barron and Malik, "Shape, illumination, and reflectance from shading", TPAMI, 2015.
"""


def kernels(mult=1.4):
    """
    Create binomial kernels.

    Inputs
    ------
    mult            Sum of each binomial kernel. See paper for explanation.

    Outputs
    -------
    h_kernel        Binomial kernel along the height axis, of shape [5, 1] and sum mult.
    w_kernel        Binomial kernel along the width axis, of shape [1, 5] and sum mult.
    """
    h_kernel = mult * np.array([0.0625, 0.25, 0.3750, 0.25, 0.0625], dtype=np.float32)
    w_kernel = mult * np.array([0.0625, 0.25, 0.3750, 0.25, 0.0625], dtype=np.float32)

    return h_kernel[:, np.newaxis], w_kernel[np.newaxis, :]


def downsample(img):
    """
    Downsample an image by a factor of 2, using a binomial 5-tap filter (multiplied by a factor mult)
    as the anti-aliasing filter.

    Inputs
    ------
    img             Input image. Tensor of shape [N, H, W, C].

    Outputs
    -------
                    Downsampled image. Tensor of shape [N, H/2, W/2, C].
    """

    # Blur by applying a binomial 5x5 kernel with sum mult^2 and subsample by a factor of 2.
    C = img.shape[3]

    # Get kernels, and turn them into a single kernel performing per-channel blurring
    h_kernel, w_kernel = kernels()
    kernel = np.zeros((5, 5, C, C), dtype=np.float32)
    for i in range(C):
        kernel[:, :, i, i] = h_kernel * w_kernel

    return tf.nn.conv2d(img, kernel, strides=(1, 2, 2, 1), padding='SAME')


def upsample(img, output_shape):
    """
    Implement the transpose of downsample().

    Inputs
    ------
    img             Input image. Tensor of shape [N, H, W, C].
    output_shape    Shape of output image.

    Outputs
    -------
                    Upsampled image. Tensor of shape [N, 2H, 2W, C].
    """
    # Number of channels
    C = img.shape[3]

    # Get kernel for deconvolution (transpose of convolution)
    h_kernel, w_kernel = kernels()
    kernel = np.zeros((5, 5, C, C), dtype=np.float32)
    for i in range(C):
        kernel[:, :, i, i] = h_kernel * w_kernel

    return tf.nn.conv2d_transpose(img, kernel, output_shape, strides=(1, 2, 2, 1), padding='SAME')


def im2pyr(img, num_levels):
    """
    Convert an image into a pyramid.

    Inputs
    ------
    img             Input image. Tensor of shape [N, H, W, C].
    num_levels      Number of pyramid levels.

    Outputs
    -------
    pyr             Image pyramid, represented as a list of tensors of decreasing spatial size.
    """
    # Initialize pyramid using the original image.
    pyr = [img]

    # The next level of the pyramid is a downsampled version of the previous image.
    for n in range(num_levels - 1):
        pyr.append(downsample(pyr[-1]))

    return pyr


def pyr2im(pyr):
    """
    Go from pyramid to image by applying the transpose operator.

    Inputs
    ------
    pyr             Image pyramid, represented as a list of tensors of decreasing spatial size.

    Outputs
    -------
    img             Output image, computed by applying the transpose of the pyramid-generating operator to the pyramid.
    """

    # Initialize by taking the smallest scale.
    img = pyr[-1]

    # For each larger scale, upsample and add the image in the new scale.
    for next_img in pyr[:-1][::-1]:
        img = upsample(img, next_img.shape) + next_img

    return img

--------------------------------------------------------------------------------
/sft_utils.py:
--------------------------------------------------------------------------------
import numpy as np


def get_z_shape_list(z_shape, num_layers, gen_arch):
    """
    Computes the spatial size of the input to all deconv layers of the generator.

    Inputs
    ------
    z_shape         The spatial size of the input to the first deconvolution
    num_layers      Number of layers in generator
    gen_arch        Generator architecture

    Outputs
    -------
    z_shape_list    A list whose ith element is the spatial input size for the
                    ith deconvolution layer

    """

    # First, make sure the kernel size of all layers is the same.
    ind1 = -1
    kernel_sizes = []
    while True:
        ind1 = gen_arch.find('D', ind1+1)  # Index of "D"
        ind2 = gen_arch.find('-', ind1+1)  # Index of next "-"
        if ind1 >= 0 and ind2 < 0:
            k = int(gen_arch[ind1+1:])
        elif ind1 >= 0:
            k = int(gen_arch[ind1+1:ind2])
        if ind1 < 0 or ind2 < 0:
            break
        kernel_sizes.append(k)
    if len(set(kernel_sizes)) != 1:
        raise ValueError('All deconvolution kernel sizes must be the same ({})'.format(kernel_sizes))

    # The (constant) kernel size
    kernel_size = kernel_sizes[0]

    # For each deconv, multiply by 2 and add kernel_size-2, see Section 4.1 of paper.
    return [[d * 2 ** i + kernel_size - 2 for d in z_shape] for i in range(num_layers)]


def get_feed_dict(local_gen_input_list, global_gen_input, random_phases, dim_z_local_list, dim_z_global,
                  dim_z_periodic, do_tie_phases, batch_size, z_shape_list):
    """
    Returns a dictionary to be used as a feed_dict. Fills placeholders needed for generator.

    Inputs
    ------
    local_gen_input_list    List of placeholders used as local inputs to generator
    global_gen_input        Placeholder used as global input to generator
    random_phases           Placeholder used as phases for periodic input to generator
    dim_z_local_list        List of local input dimension input to each deconvolution
    dim_z_global            Global input dimension (only input to first deconvolution)
    dim_z_periodic          Number of periodic dimensions used in total (not a list)
    do_tie_phases           If True, only generate two numbers used as shifts. If False,
                            generate one phase for each periodic dimension.
    batch_size              Number of images in a batch
    z_shape_list            List of spatial sizes input to each deconvolution layer

    Outputs
    -------
    feed_dict               dictionary with all placeholders as keys and the fed values as values
    """

    # Generate local and global noise, as well as shifts (either phases or two shifts)
    z_local_list, z_global, phases = sample_zl_zg_phases(dim_z_local_list, dim_z_global, dim_z_periodic if not do_tie_phases else 2,
                                                         batch_size, z_shape_list)

    feed_dict = {}

    # Add local maps to dictionary
    for local_gen_input, z_local, dz in zip(local_gen_input_list, z_local_list, dim_z_local_list):
        if dz > 0:
            feed_dict[local_gen_input] = z_local

    # Add global maps to dictionary
    if dim_z_global > 0:
        feed_dict[global_gen_input] = z_global

    # Add phases (or shifts) to dictionary. If do_tie_phases is True, multiply by an arbitrary factor
    # to make sure shift is large enough (rather than just leaving it at [0, 2pi])
    if dim_z_periodic > 0:
        feed_dict[random_phases] = phases if not do_tie_phases else phases * z_shape_list[0][0] * 10.0 / (2*np.pi)

    return feed_dict


def sample_zl_zg_phases(dim_z_local_list, dim_z_global, dim_z_periodic, batch_size, z_shape_list):
    """
    Sample local, global noise, and phases.

    Inputs
    ------
    dim_z_local_list    List of local input dimension input to each deconvolution
    dim_z_global        Global input dimension (only input to first deconvolution)
    dim_z_periodic      Number of periodic dimensions used in total (not a list)
    batch_size          Number of images in a batch
    z_shape_list        List of spatial sizes input to each deconvolution layer

    Outputs
    -------
    z_local_list        List of uniformly sampled maps (or None values when the local dimension is 0)
    z_global            Tensor of shape [batch_size, 1, 1, dim_z_global] sampled uniformly in [-1, 1]
    phases              Tensor of shape [batch_size, 1, 1, dim_z_periodic] sampled uniformly in [0, 2pi]
    """

    # Local noise maps are uniform in [-1, 1] and are spatially i.i.d.
    z_local_list = []
    for i, (z_shape, dzl) in enumerate(zip(z_shape_list, dim_z_local_list)):
        # If number of local dimensions is positive, sample and append to list. Otherwise append None.
        if dzl > 0:
            z_local = np.random.uniform(-1.0, 1.0, size=[batch_size, z_shape[0], z_shape[1], dzl])
            z_local_list.append(z_local)
        else:
            z_local_list.append(None)
    else:
        # for/else: this always runs once the loop finishes (there is no break).
        # z_local is not used after this point; only z_local_list is returned.
        z_local = None

    # Global noise is uniform in [-1, 1], and constant spatially.
    if dim_z_global > 0:
        z_global = np.random.uniform(-1., 1., size=[batch_size, 1, 1, dim_z_global])
    else:
        z_global = None

    # Phases are uniform in [0, 2pi]
    if dim_z_periodic > 0:
        phases = np.random.uniform(0, 2*np.pi, size=[batch_size, 1, 1, dim_z_periodic])
    else:
        phases = None

    return z_local_list, z_global, phases

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import time
import os
import sys
from utils import create_dir_if_needed, plot_image, save_image
from PIL import Image
import argparse
from make_net import gen2disc
from create_networks import create_generator_discriminator, create_loss, create_dg_optimizers, create_no_nan_assertions
from apply_warps import apply_warps
from create_unwarper import create_unwarper, create_w_optimizers
from sft_utils import get_z_shape_list, get_feed_dict


def param_names():
    """
    Create list of all parameter names.
    """
    return ['num_steps', 'gen_arch', 'disc_arch', 'dim_z_local_list', 'dim_z_periodic_list', 'dim_z_global',
            'num_gan_updates', 'num_shape_updates', 'n_learning_rate', 't_learning_rate',
            'n_smoothness_weight', 't_smoothness_weight', 'int_weight', 'output_shape_gen', 'input_shape_w',
            'image_path']


def str2bool(s):
    """
    Get boolean from string.
    """
    if s not in ['False', 'True']:
        raise ValueError('Not a valid boolean string')
    return s == 'True'


def train(args):
    """
    Train three-player game.
    """

    # Create folders if they don't exist
    image_dir = os.path.join(args.output_folder, 'output_images')
    model_dir = os.path.join(args.output_folder, 'models')
    create_dir_if_needed(image_dir)
    create_dir_if_needed(model_dir)

    # For disc_arch use the discriminator for gen_arch (using DCGAN rules), unless specified otherwise
    if args.disc_arch == '':
        args.disc_arch = gen2disc(args.gen_arch)

    # Save parameters to file
    with open(os.path.join(model_dir, 'params.txt'), 'w') as f:
        for p in param_names():
            f.write('{:<30} {:>30}\n'.format(p, str(getattr(args, p))))

    tf.reset_default_graph()

    # Define constants for training
    batch_size = 25  # number of patches to use for each batch
    learning_rate = 0.0002  # network learning rate
    beta1 = 0.5  # parameter for Adam optimizer
    wd_mult = 1e-8  # weight decay parameter

    # Define constants for network architecture
    do_tie_phases = True
    global_mlp_hidden_units = 60

    # Compute total number of periodic dimensions
    dim_z_periodic = sum(args.dim_z_periodic_list)

    # Load input image
    img = np.array(Image.open(args.image_path), dtype=np.float32) / 255.0
    img_height = img.shape[0]
    img_width = img.shape[1]
    num_channels = img.shape[2]

    # Make sure number of concatenations is the same as number of deconvolutions
    num_deconv = args.gen_arch.count('D')
    assert len(args.dim_z_periodic_list) == num_deconv, \
        "Number of concatenations must be the same as the number of deconvolutions"
    assert len(args.dim_z_local_list) == num_deconv, \
        "Number of concatenations must be the same as the number of deconvolutions"

    # Define shape of input noise
    z_shape = [x // (2 ** num_deconv) for x in args.output_shape_gen]
    # argparse returns a list when -output_shape_gen is given on the command line (the default is a tuple),
    # so convert before comparing.
    assert tuple([x * 2 ** num_deconv for x in z_shape]) == tuple(args.output_shape_gen), \
        "Output size must be divisible by scale factor ({} vs. {})".format(args.output_shape_gen, 2 ** num_deconv)

    assert int(args.gen_arch[args.gen_arch.rfind('-')+1:]) == num_channels, \
        "Output of generator must match number of channels in image"

    # Define placeholders for sample points and patch sizes
    loc = tf.placeholder(tf.float32, shape=[None, 2], name='loc')
    shape = tf.placeholder(tf.int64, shape=[2], name='shape')

    # Create unwarper
    warps, n, t, n_smoothness_loss, t_smoothness_loss, integrability_loss = create_unwarper(img_height, img_width)

    # Get unwarped patches
    input_img_batch = np.repeat(img[np.newaxis, :, :, :], repeats=batch_size, axis=0)
    disc_input = apply_warps(input_img_batch, loc, warps, shape)

    # Create generator and discriminator (feeding the unwarped patches to the discriminator via disc_input)
    dict_ph, dict_tensors = create_generator_discriminator(args.gen_arch, args.disc_arch, z_shape, args.dim_z_global,
                                                           args.dim_z_local_list, args.dim_z_periodic_list,
                                                           global_mlp_hidden_units, do_tie_phases, do_print=False,
                                                           disc_input=disc_input)

    # Get placeholders from generator and discriminator
    is_training = dict_ph.get('is_training', None)
    disc_input = dict_ph.get('disc_input', None)
    global_gen_input = dict_ph.get('global_gen_input', None)
    local_gen_input_list = dict_ph.get('local_gen_input_list', None)
    random_phases = dict_ph.get('random_phases', None)

    # Get tensors from generator and discriminator
    disc_fake = dict_tensors.get('disc_fake', None)
    disc_real = dict_tensors.get('disc_real', None)
    gen_sample = dict_tensors.get('gen_sample', None)

    # Define losses for three players
    shape_loss = args.int_weight * integrability_loss \
                 + args.n_smoothness_weight * n_smoothness_loss \
                 + args.t_smoothness_weight * t_smoothness_loss \
                 + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(disc_real), logits=disc_real))

    gen_loss, disc_loss = create_loss(disc_real, disc_fake, wd_mult)

    # Create optimizers and update step operations for unwarper, generator and discriminator
    train_gen, train_disc = create_dg_optimizers(gen_loss, disc_loss, learning_rate, beta1)
    train_shape = create_w_optimizers(shape_loss, args.n_learning_rate, args.t_learning_rate)

    # Create assertions to check that no values in the network are NaN
    no_nan_assertions = create_no_nan_assertions()

    # Create list of input sizes for each deconvolution layer
    # Use padding as explained in Section 4.1 of the paper.
    z_shape_list = get_z_shape_list(z_shape, len(args.dim_z_local_list), gen_arch=args.gen_arch)

    # Save all network parameters (except those from the unwarper -- the normal and tangent maps are saved as images)
    vars_to_save = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 'warp' not in x.name]
    saver = tf.train.Saver(vars_to_save, max_to_keep=1)

    with tf.Session() as sess:

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # Define variables used for timing training
        i_prev = 0
        t_prev = time.time()

        # For each iteration, train generator and then train discriminator. Every num_gan_updates iterations,
        # update shape parameters for num_shape_updates iterations.
        for step in range(1, args.num_steps+1):

            ##########################################
            #####                                #####
            #####        Train generator         #####
            #####                                #####
            ##########################################

            # Sample Z for generator input
            feed_dict_gen_step = get_feed_dict(local_gen_input_list, global_gen_input, random_phases,
                                               args.dim_z_local_list, args.dim_z_global, dim_z_periodic, do_tie_phases,
                                               batch_size, z_shape_list)

            feed_dict_gen_step[is_training] = True  # Set batchnorm mode to training

            # Perform step
            _, gl = sess.run([train_gen, gen_loss], feed_dict=feed_dict_gen_step)


            ##########################################
            #####                                #####
            #####      Train discriminator       #####
            #####                                #####
            ##########################################

            # Sample Z for generator input
            feed_dict_disc_step = get_feed_dict(local_gen_input_list, global_gen_input, random_phases,
                                                args.dim_z_local_list, args.dim_z_global, dim_z_periodic, do_tie_phases,
                                                batch_size, z_shape_list)

            # Select size of patches randomly from input_shape_w
            shape_ = np.array([np.random.permutation(args.input_shape_w)[0]] * 2)

            # Generate random crop (start_row and start_col are the topmost row and leftmost column of the crop)
            start_row = np.random.randint(0, img_height-shape_[0], batch_size)
            start_col = np.random.randint(0, img_width-shape_[1], batch_size)

            # Add values to feed_dict
            feed_dict_disc_step[shape] = shape_
            feed_dict_disc_step[loc] = np.concatenate([start_col[:, np.newaxis] + (shape_[1] - 1) / 2.0,
                                                       start_row[:, np.newaxis] + (shape_[0] - 1) / 2.0], axis=1)

            feed_dict_disc_step[is_training] = True  # Set batchnorm mode to training

            # Perform step
            _, dl = sess.run([train_disc, disc_loss], feed_dict=feed_dict_disc_step)


            ##########################################
            #####                                #####
            #####         Train unwarper         #####
            #####                                #####
            ##########################################

            if step % args.num_gan_updates == 0:
                for _ in range(args.num_shape_updates):

                    # Select size of patches randomly from input_shape_w
                    shape_ = np.array([np.random.permutation(args.input_shape_w)[0]] * 2)

                    # Generate random crop (start_row and start_col are the topmost row and leftmost column of the crop)
                    start_row = np.random.randint(0, img_height-shape_[0], batch_size)
                    start_col = np.random.randint(0, img_width-shape_[1], batch_size)

                    feed_dict_shape_step = {is_training: False}  # Run batchnorm in evaluation mode
                    feed_dict_shape_step[shape] = shape_
                    feed_dict_shape_step[loc] = np.concatenate([start_col[:, np.newaxis] + (shape_[1] - 1) / 2.0,
                                                                start_row[:, np.newaxis] + (shape_[0] - 1) / 2.0],
                                                               axis=1)

                    # Perform step
                    _ = sess.run(train_shape, feed_dict=feed_dict_shape_step)

            # Make sure no tensors have any NaN every once in a while
            if step % 300 == 0:
                sess.run(no_nan_assertions)

            # Log progress
            if step % 100 == 0 or step == 1:
                logline = "Step {}: generator loss: {}; discriminator loss: {}; time per iteration: {} seconds."\
                    .format(step, gl, dl, (time.time() - t_prev) / (step - i_prev))
                t_prev = time.time()
                i_prev = step
                print(logline)

                sys.stdout.flush()

            # Plot stuff every once in a while
            if step % 100 == 0:

                # Sample Z for generator input
                feed_dict = get_feed_dict(local_gen_input_list, global_gen_input, random_phases, args.dim_z_local_list,
                                          args.dim_z_global, dim_z_periodic, do_tie_phases, 9, z_shape_list)

                feed_dict[is_training] = False  # Run batchnorm in evaluation mode

                # Get patch shapes
                shape_ = np.array([np.random.permutation(args.input_shape_w)[0]] * 2)

                # Generate random crop (start_row and start_col are the topmost row and leftmost column of the crop)
                start_row = np.random.randint(0, img_height-shape_[0], batch_size)
                start_col = np.random.randint(0, img_width-shape_[1], batch_size)

                feed_dict[shape] = shape_

                feed_dict[loc] = np.concatenate([start_col[:, np.newaxis] + (shape_[1] - 1) / 2.0,
                                                 start_row[:, np.newaxis] + (shape_[0] - 1) / 2.0], axis=1)

                # Get output of generator and unwarper
                generated_samples, disc_input_ = sess.run([gen_sample, disc_input], feed_dict=feed_dict)
                H, W = generated_samples.shape[1:3]
                b = 2  # Border, in pixels
                sqrt_samples = 3  # Number of rows and columns in grid of samples
                imgs = np.ones((sqrt_samples*H+(sqrt_samples-1)*b, sqrt_samples*W+(sqrt_samples-1)*b, 3),
                               dtype=np.float32)
                for row in range(sqrt_samples):
                    for col in range(sqrt_samples):
                        imgs[row*(H+b):row*(H+b)+H, col*(W+b):col*(W+b)+W, :] = \
                            generated_samples[row*sqrt_samples+col, :, :, :]  # Reshape to (3H, 3W, n_channels)

                H, W = disc_input_.shape[1:3]
                real_imgs = np.ones((sqrt_samples*H+(sqrt_samples-1)*b, sqrt_samples*W+(sqrt_samples-1)*b, 3),
                                    dtype=np.float32)
                for row in range(sqrt_samples):
                    for col in range(sqrt_samples):
                        real_imgs[row*(H+b):row*(H+b)+H, col*(W+b):col*(W+b)+W, :] = \
                            disc_input_[row*sqrt_samples+col, :, :, :]  # Reshape to (3H, 3W, n_channels)

                # Get normal maps
                n_ = sess.run(n)
                n_[:, :, 1] *= -1.0  # Flip y axis in normals
                n_img = (n_ + 1.0) / 2.0  # Map from [-1, 1] to [0, 1]

                # Plot output of generator, normal vectors, output of unwarper
                if args.do_plot:
                    plt.figure(1)
                    plt.subplot(1, 3, 1)
                    plot_image(imgs, vmin=0.0, vmax=1.0)
                    plt.subplot(1, 3, 2)
                    plot_image(n_img, title='niter={}'.format(step), vmin=0.0, vmax=1.0)
                    plt.subplot(1, 3, 3)
                    plot_image(real_imgs, vmin=0.0, vmax=1.0)
                    plt.show()

                # Every once in a while, save images
                if step % 500 == 0:
                    # Save normal map
                    image_filename = os.path.join(image_dir, 'n_img_iter_{}.jpg'.format(step))
                    save_image(image_filename, np.uint8(n_img * 255.0))

                    # Save tangent map
                    t_ = sess.run(t)
                    t_[:, :, 1] *= -1.0  # Flip y axis in tangent vectors
                    t_img = (t_ + 1.0) / 2.0  # Map from [-1, 1] to [0, 1]
                    image_filename = os.path.join(image_dir, 't_img_iter_{}.jpg'.format(step))
                    save_image(image_filename, np.uint8(t_img * 255.0))

                    # Save output of generator
                    image_filename = os.path.join(image_dir, 'generator_output_iter_{}.jpg'.format(step))
                    save_image(image_filename, np.uint8(imgs * 255.0))

                    # Save output of unwarper
                    image_filename = os.path.join(image_dir, 'unwarper_output_iter_{}.jpg'.format(step))
                    save_image(image_filename, np.uint8(real_imgs * 255.0))

            # Save generator and discriminator to file
            if args.do_save_model:
                if step % 1000 == 0:
                    saver.save(sess, os.path.join(model_dir, "model{}.ckpt".format(step)), write_meta_graph=False)


if __name__ == '__main__':
    # Define dictionary holding all parameters
    d = dict()
    d['num_steps'] = dict(type=int, default=25000, help="Number of training iterations.")
    d['gen_arch'] = dict(type=str, default="D5-256_D5-128_D5-64_D5-3",
                         help="Generator architecture (see make_net.arch_string2list() for more information).")
    d['disc_arch'] = dict(type=str, default="", help="Discriminator architecture. If unspecified use DCGAN rules.")
    d['dim_z_local_list'] = dict(type=int, nargs='*', default=(0, 0, 2, 2), help="List of local dimensions.")
    d['dim_z_periodic_list'] = dict(type=int, nargs='*', default=(2, 2, 0, 0), help="List of periodic dimensions.")
    d['dim_z_global'] = dict(type=int, default=2, help="Number of global dimensions.")

    d['num_gan_updates'] = dict(type=int, default=20, help="Number of G/D updates at a time.")
    d['num_shape_updates'] = dict(type=int, default=200, help="Number of W updates at a time.")

    d['n_learning_rate'] = dict(type=float, default=0.001, help="Learning rate for normal vectors.")
    d['t_learning_rate'] = dict(type=float, default=0.05, help="Learning rate for tangent vectors.")

    d['n_smoothness_weight'] = dict(type=float, default=100.0, help="Smoothness loss weight for normal vectors.")
    d['t_smoothness_weight'] = dict(type=float, default=100.0, help="Smoothness loss weight for tangent vectors.")
    d['int_weight'] = dict(type=float, default=10000000.0, help="Integrability loss weight.")

    d['output_shape_gen'] = dict(type=int, nargs=2, default=(192, 192), help="Spatial size of generator output.")
    d['input_shape_w'] = dict(type=int, nargs='*', default=(192, 160, 128, 96), help="Spatial size of unwarper output.")

    d['image_path'] = dict(type=str, default=None, help="Path of input image to use.")

    d['do_plot'] = dict(type=str2bool, default='False', help="Whether or not to plot results to screen.")
    d['do_save_model'] = dict(type=str2bool, default='False', help="Whether or not to save models to disk.")
    d['output_folder'] = dict(type=str, default='', help="Folder for saving results and model.")

    # Create argument parser
    parser = argparse.ArgumentParser(description='Define parameters for three-player game.')

    # Fill arguments from dictionary
    for k in d.keys():
        parser.add_argument('-' + k, **d[k])

    args = parser.parse_args()

    train(args)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2


def create_dir_if_needed(directory):
    """
    Create directory if it does not exist.
    """
    if not os.path.exists(directory):
        os.makedirs(directory)
        print("Created directory {}".format(directory))


def plot_image(image, title=None, vmin=None, vmax=None):
    """
    Helper function for image plotting. vmin and vmax are used for clipping (unless they're None).
    """
    plt.ion()

    # Make sure image is either 2d (grayscale) or 3d with one channel (grayscale) or three channels (RGB).
    if len(image.shape) == 2:
        params = {'cmap': 'gray'}
    elif len(image.shape) == 3 and image.shape[2] == 1:
        params = {'cmap': 'gray'}
        image = image[:, :, 0]
    elif len(image.shape) == 3 and image.shape[2] == 3:
        params = {}
    else:
        raise ValueError("image must be 2D, or 3D with one or three channels (but has shape {})".format(image.shape))

    # If vmin and vmax are set, apply clipping
    if vmin is not None and vmax is not None:
        params['vmin'] = vmin
        params['vmax'] = vmax
        image = np.clip(image, vmin, vmax)
    try:
        plt.imshow(image, **params)
        if title is not None:
            plt.title(title)
        plt.show()
        plt.pause(0.001)
    except Exception:
        # Plotting is best-effort (e.g. no display available); do not swallow KeyboardInterrupt
        return


def save_image(filename, img):
    """
    Save image `img` as `filename`.
    """
    # If grayscale, save as is. If RGB, flip channels (because opencv uses BGR rather than RGB).
    if len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] == 1):
        cv2.imwrite(filename, img)
    else:
        cv2.imwrite(filename, img[:, :, ::-1])
--------------------------------------------------------------------------------
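Note on patch sampling in train.py: the discriminator step, the unwarper step and the plotting step each repeat the same few lines that pick a random square patch size from `input_shape_w` and build the `loc` array of patch centers. The sketch below gathers that logic in one place using NumPy only. The helper name `sample_patch_centers` and its `seed` argument are hypothetical and not part of the repository. It also assumes an upper bound of `size - side + 1`, which differs from train.py: `randint` excludes its upper bound, so train.py never samples the last valid corner and would raise if a patch were as large as the image.

```python
import numpy as np


def sample_patch_centers(img_height, img_width, patch_sizes, batch_size, seed=None):
    """
    Hypothetical helper (not part of the repository): sample one square patch
    size and a batch of patch centers, mirroring the sampling code in train.py.
    """
    rng = np.random.RandomState(seed)

    # One patch size is shared by the whole batch.
    side = int(rng.choice(patch_sizes))
    shape = np.array([side, side])

    # Top-left corners of the crops. randint's upper bound is exclusive, so the
    # +1 (an assumption; train.py omits it) keeps the last valid corner reachable.
    start_row = rng.randint(0, img_height - side + 1, size=batch_size)
    start_col = rng.randint(0, img_width - side + 1, size=batch_size)

    # Patch centers in (x, y) order, shape [batch_size, 2], as fed to `loc`.
    half = (side - 1) / 2.0
    loc = np.stack([start_col + half, start_row + half], axis=1)
    return shape, loc


if __name__ == '__main__':
    shape_, loc_ = sample_patch_centers(640, 640, (192, 160, 128, 96), batch_size=25, seed=0)
    print(shape_, loc_.shape)
```

With such a helper, each of the three call sites would reduce to `feed_dict[shape], feed_dict[loc] = sample_patch_centers(...)`. Whether the `+1` is appropriate depends on how `apply_warps` treats patches that touch the image border, which is not shown here.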