├── .DS_Store ├── img └── img.png ├── get-list.py ├── README.md ├── view.py ├── layer.py ├── image.py ├── main.py ├── reproject.py ├── network.py ├── loss.py ├── data.py └── monnet.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yueeey/sketcch3D/HEAD/.DS_Store -------------------------------------------------------------------------------- /img/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yueeey/sketcch3D/HEAD/img/img.png -------------------------------------------------------------------------------- /get-list.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | p = Path('/vol/research/zy/dataSets/shapeMVD/Chair/hires/03001627') 4 | folder_list = [x for x in p.iterdir() if x.is_dir()] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Practical Sketch-Based 3D Shape Generation 2 | 3 | ## Contents 4 | 5 | - [Introduction](#Introduction) 6 | - [Requirements](#Requirements) 7 | - [Download Dataset](#Download-Dataset) 8 | - [Results](#Results) 9 | 10 | ## Introduction 11 | 12 | This repository contains the TensorFlow implementation of [Towards Practical Sketch-Based 3D Shape Generation: The Role of Professional Sketches](https://ieeexplore.ieee.org/document/9272370). 13 | 14 | Detailed usage instructions for training and evaluation can be found in the README file of each sub-task (see [Requirements](#Requirements)).
15 | 16 | If you use our code or dataset, please cite our work: 17 | 18 | @ARTICLE{sketch3d2020, 19 | author={Zhong, Yue and Qi, Yonggang and Gryaditskaya, Yulia and Zhang, Honggang and Song, Yi-Zhe}, 20 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 21 | title={Towards Practical Sketch-Based 3D Shape Generation: The Role of Professional Sketches}, 22 | year={2021}, 23 | volume={31}, 24 | number={9}, 25 | pages={3518-3528}, 26 | doi={10.1109/TCSVT.2020.3040900} 27 | } 28 | 29 | ## Requirements 30 | 31 | First, make sure that you have all dependencies in place. 32 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/). 33 | 34 | Please refer to the README file in each sub-task for detailed instructions. 35 | 36 | ## Download Dataset 37 | 38 | The dataset can be downloaded directly from [Dataset](https://pan.baidu.com/s/1wpf6Tc7h55TN6bdUYXQsPQ) (extraction code: fhp7). 39 | 40 | Most of our experiments are conducted on models from the chair category of ShapeNetCore V2. We selected these categories according to the following principles: 1) Easy to sketch. 2) Generality. 3) View differentiability. 4) Shape genus higher than 1. 5) Large inter-category variance. We generate three sketch categories with distinctive styles, which we refer to as naive, stylized and style-unified. Please refer to the paper for further details. 41 | 42 | 43 | ## Results 44 | 45 | We show improved performance of deep image modeling. 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /view.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the Sketch Modeling project.
3 | 4 | Copyright (c) 2017 5 | -Zhaoliang Lun (author of the code) / UMass-Amherst 6 | 7 | This is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This software is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this software. If not, see . 19 | """ 20 | 21 | 22 | #import tensorflow as tf 23 | import numpy as np 24 | 25 | class Views(object): 26 | 27 | def __init__(self, filename, num_views=-1): 28 | """ 29 | self.views: V x 3 30 | self.groups G x v 31 | """ 32 | 33 | f = open(filename, 'r') 34 | 35 | f.readline() # OFF 36 | self.num_views, self.num_groups, num_edges = map(int, f.readline().split()) 37 | 38 | view_data = [] 39 | for view_id in range(self.num_views): 40 | view_data.append(list(map(float, f.readline().split()))) 41 | self.views = np.array(view_data) 42 | 43 | group_data = [] 44 | for group_id in range(self.num_groups): 45 | group_data.append(list(map(int, f.readline().split()[1:]))) 46 | self.groups = np.array(group_data) 47 | 48 | f.close() 49 | 50 | if num_views >= 0: # select views 51 | self.num_views = num_views 52 | self.num_groups = 0 53 | self.views = self.views[:self.num_views] 54 | self.groups = self.groups[:0] 55 | 56 | self.num_edges = self.num_views+self.num_groups-2 57 | self.edge_size = 2 58 | 59 | # HACK: minimal data for local testing 60 | #self.num_views = 3 61 | #self.num_groups = 0 62 | #self.num_edges = 1 63 | #self.edge_size = 2 64 | #self.views = self.views[:self.num_views] 65 | #self.groups = self.groups[:self.num_groups] 66 | 67 | #print('Views:') 68 | 
#print(self.views) 69 | #print('Groups:') 70 | #print(self.groups) 71 | 72 | def view2angle(view): 73 | """ 74 | input: 75 | view : 3 : (x,y,z) 76 | output: 77 | angle : 4 : (cos(theta), sin(theta), cos(phi), sin(phi)) 78 | """ 79 | 80 | r = np.linalg.norm(view) # sqrt(x^2+y^2+z^2) 81 | rxz = np.linalg.norm(view[[0,2]]) # sqrt(x^2+z^2) 82 | ct = view[1] / r # cos(theta) = y/r 83 | st = rxz / r # sin(theta) = sqrt(x^2+z^2)/r 84 | if rxz>0: 85 | cp = view[0] / rxz # cos(phi) = x / sqrt(x^2+z^2) 86 | sp = view[2] / rxz # sin(phi) = z / sqrt(x^2+z^2) 87 | else: # zenith point 88 | cp = 0.0 89 | sp = 0.0 90 | return [ct, st, cp, sp] 91 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the Sketch Modeling project. 3 | 4 | Copyright (c) 2017 5 | -Zhaoliang Lun (author of the code) / UMass-Amherst 6 | 7 | This is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This software is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this software. If not, see . 
19 | """ 20 | 21 | 22 | import tensorflow as tf 23 | import numpy as np 24 | 25 | import tensorflow.contrib.layers as tf_layers 26 | import tensorflow.contrib.framework as tf_framework 27 | 28 | WEIGHT_STDDEV = 0.005 29 | WEIGHT_DECAY = 0.0001 30 | BN_DECAY = 0.997 31 | BN_EPSILON = 1e-5 32 | 33 | def leaky_relu(tensor, slope=0.2): 34 | """ 35 | input: 36 | tensor : input tensor of any shape 37 | output: 38 | result : output tensor having the same shape as input tensor 39 | """ 40 | return tf.maximum(tensor*slope, tensor) 41 | 42 | def unet_scopes(bn_scope): 43 | 44 | bn_params = { 45 | 'is_training': True, 46 | 'decay': BN_DECAY, 47 | 'epsilon': BN_EPSILON, 48 | 'trainable': False, 49 | 'updates_collections': bn_scope, 50 | } 51 | 52 | with tf_framework.arg_scope( 53 | [tf_layers.conv2d, tf_layers.fully_connected], 54 | weights_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_STDDEV), 55 | weights_regularizer=tf_layers.l2_regularizer(WEIGHT_DECAY), 56 | biases_initializer=tf.zeros_initializer(), 57 | normalizer_fn=tf_layers.batch_norm, 58 | normalizer_params=bn_params, 59 | activation_fn=leaky_relu) as scope: 60 | if bn_scope is None: 61 | return scope 62 | else: 63 | with tf_framework.arg_scope([tf_layers.batch_norm], **bn_params) as scope_with_bn: 64 | return scope_with_bn 65 | 66 | def cnet_scopes(bn_scope): 67 | 68 | with tf_framework.arg_scope( 69 | [tf_layers.conv2d, tf_layers.fully_connected], 70 | weights_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_STDDEV), 71 | weights_regularizer=tf_layers.l2_regularizer(WEIGHT_DECAY), 72 | biases_initializer=tf.zeros_initializer(), 73 | normalizer_fn=None, 74 | activation_fn=leaky_relu) as scope: 75 | return scope 76 | 77 | def residual_layer(inputs, kernel, scope): 78 | """ 79 | input: 80 | inputs : n x H x W x C feature maps to be passed into residual block 81 | kernel : scalar internal filter kernel size 82 | scope : string scope name 83 | output: 84 | outputs : n x H x W x C output feature 
maps 85 | """ 86 | 87 | channels = inputs.get_shape()[3].value 88 | layer1 = tf_layers.conv2d(inputs, num_outputs=channels, kernel_size=kernel, stride=1, scope=scope+'/layer1') 89 | layer2 = tf_layers.conv2d(layer1, num_outputs=channels, kernel_size=kernel, stride=1, scope=scope+'/layer2', activation_fn=None) 90 | outputs = layer2 + inputs 91 | return outputs 92 | 93 | def unconv_layer(inputs, num_outputs, kernel_size, stride, scope, normalizer_fn=tf_layers.batch_norm, activation_fn=tf.nn.relu): 94 | """ 95 | input: 96 | inputs : n x H x W x C feature maps to be passed into unconv layer 97 | num_outputs : scalar number of channels in output feature map 98 | kernel_size : scalar internal filter kernel size 99 | scope : string scope name 100 | normalizer_fn : function normalizer function 101 | activation_fn : function activation function 102 | output: 103 | outputs : n x H x W x C output feature maps 104 | """ 105 | 106 | # return tf_layers.conv2d_transpose(inputs, num_outputs=num_outputs, kernel_size=kernel_size, stride=stride, scope=scope, normalizer_fn=normalizer_fn, activation_fn=activation_fn) 107 | 108 | h = inputs.get_shape()[1].value 109 | w = inputs.get_shape()[2].value 110 | c = inputs.get_shape()[3].value 111 | 112 | # upsampled = tf.image.resize_bilinear(inputs, [h*stride, w*stride]) 113 | upsampled = tf.image.resize_nearest_neighbor(inputs, [h*stride, w*stride]) 114 | 115 | outputs = tf_layers.conv2d(upsampled, num_outputs=num_outputs, kernel_size=kernel_size, stride=1, scope=scope, normalizer_fn=normalizer_fn, activation_fn=activation_fn) 116 | 117 | # features = tf_layers.conv2d(upsampled, num_outputs=c, kernel_size=kernel_size, stride=1, scope=scope+'/conv1') 118 | # outputs = tf_layers.conv2d(features, num_outputs=num_outputs, kernel_size=kernel_size, stride=1, scope=scope+'/conv2', normalizer_fn=normalizer_fn, activation_fn=activation_fn) 119 | 120 | return outputs 121 | 
-------------------------------------------------------------------------------- /image.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the Sketch Modeling project. 3 | 4 | Copyright (c) 2017 5 | -Zhaoliang Lun (author of the code) / UMass-Amherst 6 | 7 | This is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This software is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this software. If not, see . 19 | """ 20 | 21 | 22 | import tensorflow as tf 23 | import numpy as np 24 | from scipy import ndimage 25 | 26 | import os 27 | 28 | ########################### image processing ########################### 29 | 30 | def normalize_image(image): 31 | # normalize to [-1.0, 1.0] 32 | if image.dtype == tf.uint8: 33 | return tf.to_float(image)/127.5-1.0 34 | elif image.dtype == tf.uint16: 35 | return tf.to_float(image)/32767.5-1.0 36 | else: 37 | return tf.to_float(image) 38 | 39 | def unnormalize_image(image, maxval=255.0): 40 | # restore image to [0.0, maxval] 41 | return (image+1.0)*maxval*0.5 42 | 43 | def saturate_image(image, dtype=tf.uint8): 44 | return tf.saturate_cast(image, dtype) 45 | 46 | def convert_to_rgb(image, channels=3): 47 | return tf.tile(image, [1,1,1,channels]) 48 | 49 | 50 | ########################### masks ########################### 51 | 52 | def extract_boolean_mask(image): 53 | """ 54 | input: 55 | image: n x H x W x C : images with value range [-1.0, 1.0] in each channel 56 | output: 57 | mask: n x H x W x 1 : 
boolean mask (depth channel value < 0.9) 58 | """ 59 | 60 | depth = tf.slice(image, [0,0,0,3], [-1,-1,-1,1]) 61 | shape = depth.get_shape() 62 | mask = tf.where(tf.greater(depth, 0.9), 63 | tf.constant(False, dtype=tf.bool, shape=shape), 64 | tf.constant(True, dtype=tf.bool, shape=shape)) 65 | return mask 66 | 67 | def convert_to_real_mask(bool_mask): 68 | """ 69 | input: 70 | bool_mask: boolean mask image 71 | output: 72 | real_mask: real number mask image (-1.0: false, 1.0: true) 73 | """ 74 | 75 | shape = bool_mask.get_shape() 76 | return tf.where(bool_mask, 77 | tf.constant(1.0, dtype=tf.float32, shape=shape), 78 | tf.constant(-1.0, dtype=tf.float32, shape=shape)) 79 | 80 | def convert_to_boolean_mask(real_mask): 81 | """ 82 | input: 83 | real_mask: real number mask image (-1.0: false, 1.0: true) 84 | output: 85 | bool_mask: boolean mask image 86 | """ 87 | shape = real_mask.get_shape() 88 | return tf.where(tf.greater(real_mask, 0.0), 89 | tf.constant(True, dtype=tf.bool, shape=shape), 90 | tf.constant(False, dtype=tf.bool, shape=shape)) 91 | 92 | def apply_mask(content, mask): 93 | """ 94 | input: 95 | content: n x H x W x C : image content 96 | mask: n x H x W x 1 : image mask (>0: true) 97 | output: 98 | output: use content value if mask is true; 1.0 otherwise 99 | """ 100 | channel = content.get_shape()[3].value 101 | if channel > 1: 102 | mask = tf.tile(mask, [1,1,1,channel]) 103 | return tf.where(tf.greater(mask, 0.0), content, tf.ones_like(content)) 104 | 105 | 106 | ########################### filter ########################### 107 | 108 | def get_sobel_filter(): 109 | 110 | # 3x3 sobel filter 111 | filter_v = tf.convert_to_tensor(np.array([ \ 112 | [-1.0, 0.0, 1.0], 113 | [-2.0, 0.0, 2.0], 114 | [-1.0, 0.0, 1.0]]), dtype=tf.float32) 115 | filter_h = tf.convert_to_tensor(np.array([ \ 116 | [ 1.0, 2.0, 1.0], 117 | [ 0.0, 0.0, 0.0], 118 | [-1.0, -2.0, -1.0]]), dtype=tf.float32) 119 | return filter_v, filter_h 120 | 121 | def get_dog_filter(kernel_size): 
122 | 123 | # derivative of gaussian filter 124 | kernel_point = np.zeros((kernel_size, kernel_size)) 125 | kernel_point[kernel_size//2,kernel_size//2] = 1 126 | kernel_v = ndimage.filters.gaussian_filter(kernel_point, sigma=kernel_size//2, order=[0,1]) * (kernel_size*kernel_size) 127 | kernel_h = kernel_v.T 128 | filter_v = tf.constant(kernel_v, dtype=tf.float32) 129 | filter_h = tf.constant(kernel_h, dtype=tf.float32) 130 | filter_v = tf.expand_dims(tf.expand_dims(filter_v, -1), -1) 131 | filter_h = tf.expand_dims(tf.expand_dims(filter_h, -1), -1) 132 | return filter_v, filter_h 133 | 134 | def apply_edge_filter(images): 135 | """ 136 | input: 137 | images: n x H x W x C input images 138 | output: 139 | outputs: n x H x W x 1 output edge images 140 | """ 141 | 142 | if images.get_shape()[3].value == 1: 143 | gray_images = images 144 | else: 145 | gray_images = tf.image.rgb_to_grayscale(images) 146 | 147 | if not hasattr(apply_edge_filter, "filter"): 148 | apply_edge_filter.filter = get_dog_filter(15) 149 | 150 | edge_v = tf.nn.conv2d(gray_images, filter=apply_edge_filter.filter[0], strides=[1,1,1,1], padding='SAME') 151 | edge_h = tf.nn.conv2d(gray_images, filter=apply_edge_filter.filter[1], strides=[1,1,1,1], padding='SAME') 152 | outputs = tf.square(edge_v) + tf.square(edge_h) 153 | 154 | return outputs 155 | 156 | 157 | ########################### encoding ########################### 158 | 159 | def encode_batch_images(batch): 160 | """ 161 | input: 162 | batch: n x H x W x C input images batch 163 | output: 164 | packed: n x String output PNG-encoded strings 165 | """ 166 | # output: 167 | unpacked = tf.unstack(batch) 168 | num = len(unpacked) 169 | encoded = [None] * num 170 | for k in range(num): 171 | encoded[k] = tf.image.encode_png(unpacked[k]) 172 | return tf.stack(encoded) 173 | 174 | def encode_raw_batch_images(batch): 175 | """ 176 | input: 177 | batch: n x H x W x C input raw images batch 178 | output: 179 | packed: n x String output PNG-encoded 
strings 180 | """ 181 | return encode_batch_images(saturate_image(unnormalize_image(batch))) 182 | 183 | def write_image(name, image): 184 | """ 185 | input: 186 | name: String file name 187 | image: String PNG-encoded string 188 | """ 189 | path = os.path.dirname(name) 190 | if not os.path.exists(path): 191 | os.makedirs(path) 192 | file = open(name, 'wb') 193 | file.write(image) 194 | file.close() 195 | 196 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the Sketch Modeling project. 3 | 4 | Copyright (c) 2017 5 | -Zhaoliang Lun (author of the code) / UMass-Amherst 6 | 7 | This is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This software is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this software. If not, see . 
19 | """ 20 | 21 | 22 | import tensorflow as tf 23 | 24 | import time 25 | import os 26 | 27 | import data 28 | import monnet as mn 29 | import view as vw 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | 33 | tf.app.flags.DEFINE_boolean('train', False, 34 | """Flag for training routine.""") 35 | tf.app.flags.DEFINE_boolean('test', False, 36 | """Flag for testing routine.""") 37 | tf.app.flags.DEFINE_boolean('encode', False, 38 | """Flag for encoding routine.""") 39 | tf.app.flags.DEFINE_boolean('predict_normal', True, 40 | """Flag for predicting normal.""") 41 | tf.app.flags.DEFINE_boolean('continuous_view', False, 42 | """Flag for using continuous view architecture.""") 43 | tf.app.flags.DEFINE_boolean('no_adversarial', False, 44 | """Flag for adversarial loss term.""") 45 | tf.app.flags.DEFINE_integer('batch_size', 2, 46 | """Number of images to process in a batch.""") 47 | tf.app.flags.DEFINE_integer('image_size', 256, 48 | """Size of images to be learned.""") 49 | tf.app.flags.DEFINE_integer('sketch_variations', 4, 50 | """Number of variations on input source.""") 51 | tf.app.flags.DEFINE_string('sketch_views', 'F', 52 | """Views used in sketch input ( [F]ront / [T]op / [S]ide )""") 53 | tf.app.flags.DEFINE_float('max_epochs', 100.0, 54 | """Maximum epochs for optimization.""") 55 | tf.app.flags.DEFINE_float('gpu_fraction', 0.9, 56 | """Upper-bound fraction of GPU memory usage.""") 57 | tf.app.flags.DEFINE_string('data_dir', '/vol/research/zy/dataSets/shapeMVD/Chair/', 58 | """Directory containing training/testing images.""") 59 | tf.app.flags.DEFINE_string('sketch_dir', '/vol/research/ycres/zy/dataSets/occ/ShapeNet/', 60 | """Directory containing training/testing images.""") 61 | tf.app.flags.DEFINE_string('sketch_set', '/naive_mad', 62 | """Directory containing training/testing images.""") 63 | tf.app.flags.DEFINE_string('train_dir', '/vol/research/zyres/3dv/baselines/SketchModeling/Network/Checkpoint/', 64 | """Directory where to write training logs.""") 65 | 
tf.app.flags.DEFINE_string('test_dir', '/vol/research/zyres/3dv/baselines/SketchModeling/Network/output/sty_mad1/', 66 | """Directory where to write testing logs.""") 67 | tf.app.flags.DEFINE_string('check_dir', '/vol/research/zyres/3dv/baselines/SketchModeling/Network/output/sty_mad/', 68 | """Directory where to write testing logs.""") 69 | tf.app.flags.DEFINE_string('encode_dir', './../../../../Data/CharacterDraw/encode/', 70 | """Directory where to write encoding logs.""") 71 | tf.app.flags.DEFINE_string('view_file', 'view.off', 72 | """File with view points information.""") 73 | 74 | def main(argv=None): 75 | 76 | print('start running...') 77 | start_time = time.time() 78 | 79 | ############################################ build graph ############################################ 80 | 81 | monnet = mn.MonNet(FLAGS) 82 | 83 | if int(FLAGS.train) + int(FLAGS.test) + int(FLAGS.encode) != 1: 84 | print('please specify \'train\' or \'test\' or \'encode\'') 85 | return 86 | 87 | views = vw.Views(os.path.join(FLAGS.data_dir, 'view', FLAGS.view_file)) 88 | 89 | if FLAGS.train: 90 | train_names, train_sources, train_targets, train_masks, train_angles, num_train_shapes = data.load_train_data(FLAGS, views) 91 | valid_names, valid_sources, valid_targets, valid_masks, valid_angles, num_valid_shapes = data.load_validate_data(FLAGS, views) 92 | 93 | with tf.variable_scope("monnet") as scope: 94 | monnet.build_network(\ 95 | names=train_names, 96 | sources=train_sources, 97 | targets=train_targets, 98 | masks=train_masks, 99 | angles=train_angles, 100 | views=views, 101 | is_training=True) 102 | scope.reuse_variables() # sharing weights 103 | monnet.build_network(\ 104 | names=valid_names, 105 | sources=valid_sources, 106 | targets=valid_targets, 107 | masks=valid_masks, 108 | angles=valid_angles, 109 | views=views, 110 | is_validation=True) 111 | elif FLAGS.test: 112 | test_names, test_sources, test_targets, test_masks, test_angles, num_test_shapes = data.load_test_data(FLAGS, 
views) 113 | 114 | with tf.variable_scope("monnet") as scope: 115 | monnet.build_network(\ 116 | names=test_names, 117 | sources=test_sources, 118 | targets=test_targets, 119 | masks=test_masks, 120 | angles=test_angles, 121 | views=views, 122 | is_testing=True) 123 | elif FLAGS.encode: 124 | encode_names, encode_sources, encode_targets, encode_masks, encode_angles, num_encode_shapes = data.load_encode_data(FLAGS, views) 125 | 126 | with tf.variable_scope("monnet") as scope: 127 | monnet.build_network(\ 128 | names=encode_names, 129 | sources=encode_sources, 130 | targets=encode_targets, 131 | masks=encode_masks, 132 | angles=encode_angles, 133 | views=views, 134 | is_encoding=True) 135 | 136 | 137 | ############################################ compute graph ############################################ 138 | 139 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction) 140 | 141 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, 142 | log_device_placement=False, 143 | allow_soft_placement=True)) as sess: 144 | 145 | if FLAGS.train: 146 | monnet.train(sess, views, num_train_shapes, num_valid_shapes) 147 | elif FLAGS.test: 148 | monnet.test(sess, views, num_test_shapes) 149 | elif FLAGS.encode: 150 | monnet.encode(sess, views, num_encode_shapes) 151 | 152 | sess.close() 153 | 154 | duration = time.time() - start_time 155 | print('total running time: %.1f\n' % duration) 156 | 157 | 158 | if __name__ == '__main__': 159 | tf.app.run() -------------------------------------------------------------------------------- /reproject.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the Sketch Modeling project. 
3 | 4 | Copyright (c) 2017 5 | -Zhaoliang Lun (author of the code) / UMass-Amherst 6 | 7 | This is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This software is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this software. If not, see . 19 | """ 20 | 21 | 22 | import tensorflow as tf 23 | import numpy as np 24 | 25 | import os 26 | 27 | class ReProj(object): 28 | 29 | def __init__(self): 30 | self.proj = np.identity(4) 31 | self.view = np.identity(4) 32 | 33 | def set_ortho_projection(self, l=-2.5, r=2.5, b=-2.5, t=2.5, n=0.1, f=5.0): 34 | """ 35 | args: 36 | l: left 37 | r: right 38 | b: bottom 39 | t: top 40 | n: near 41 | f: far 42 | ref: https://www.opengl.org/sdk/docs/man2/xhtml/glOrtho.xml 43 | """ 44 | self.proj = np.array([ \ 45 | [2.0/(r-l), 0.0, 0.0, -(r+l)/(r-l)], 46 | [0.0, 2.0/(t-b), 0.0, -(t+b)/(t-b)], 47 | [0.0, 0.0, -2.0/(f-n), -(f+n)/(f-n)], 48 | [0.0, 0.0, 0.0, 1.0 ]]) 49 | self.proj_inv = np.linalg.inv(self.proj) 50 | 51 | def set_viewpoint(self, viewpoint): 52 | """ 53 | args: 54 | viewpoint: eye position (assuming center at origin, up on Y axis) 55 | ref: http://www.ibm.com/support/knowledgecenter/ssw_aix_53/com.ibm.aix.opengl/doc/openglrf/gluLookAt.htm 56 | """ 57 | E = viewpoint 58 | C = np.array([0.0, 0.0, 0.0]) 59 | U = np.array([0.0, 1.0, 0.0]) 60 | L = C-E; 61 | L = L/np.linalg.norm(L) 62 | S = np.cross(L, U) 63 | if np.linalg.norm(S) == 0: 64 | U = np.array([0.0, 0.0, -1.0]) 65 | S = np.cross(L, U) 66 | S = S/np.linalg.norm(S) 67 | Up = np.cross(S, L) 68 | R = 
np.identity(4) 69 | R[0, 0:3] = S 70 | R[1, 0:3] = Up 71 | R[2, 0:3] = -L 72 | T = np.identity(4) 73 | T[0:3, 3] = -E 74 | self.view = np.dot(R, T) 75 | self.view_inv = np.linalg.inv(self.view) 76 | 77 | def transform(self, depth): 78 | """ 79 | input: 80 | depth: H x W depth map with value range [-1, 1] 81 | output: 82 | points: (HxW) x 3 point set 83 | """ 84 | H = depth.shape[0] 85 | W = depth.shape[1] 86 | num_points = np.count_nonzero(depth<1.0) 87 | valid_points = [None] * num_points 88 | point_id = 0 89 | for u in range(W): 90 | for v in range(H): 91 | if depth[v,u] < 1.0: 92 | valid_points[point_id] = [(u*2.0+1.0-W)/W, (H-v*2.0-1.0)/H, depth[v,u], 1.0] 93 | point_id += 1 94 | if num_points<=0: 95 | valid_points = np.empty([0,4]) 96 | points = np.dot(self.view_inv, np.dot(self.proj_inv, np.array(valid_points).T))[0:3,:].T 97 | return points 98 | 99 | def export_ply(filename, points, normals=None): 100 | """ 101 | args: 102 | filename: string file name 103 | points: (HxW) x 3 point set 104 | normals: (HxW) x 3 point set 105 | """ 106 | path = os.path.dirname(filename) 107 | if not os.path.exists(path): 108 | os.makedirs(path) 109 | f = open(filename, 'w') 110 | f.write('ply\n') 111 | f.write('format ascii 1.0\n') 112 | f.write('element vertex %d\n' % points.shape[0]) 113 | f.write('property float x\n') 114 | f.write('property float y\n') 115 | f.write('property float z\n') 116 | if normals is not None: 117 | f.write('property float nx\n') 118 | f.write('property float ny\n') 119 | f.write('property float nz\n') 120 | f.write('end_header\n') 121 | for k in range(points.shape[0]): 122 | f.write('%f %f %f\n' % (points[k,0], points[k,1], points[k,2])) 123 | if normals is not None: 124 | f.write('%f %f %f\n' % (normals[k,0], normals[k,1], normals[k,2])) 125 | f.close() 126 | 127 | def transform_tensor(predicts, views): 128 | """ 129 | input: 130 | predicts : (n*V) x H x W x 4 predicted tensor (in n batches & V views) 131 | views : V x 3 view point positions (numpy 
array) 132 | output: 133 | points : (n*V) x H x W x 3 re-projected points position tensor 134 | dirs : (n*V) x H x W x 3 re-projected normals direction tensor 135 | """ 136 | 137 | shape = predicts.get_shape().as_list() 138 | num_views = views.shape[0] 139 | num_batches = shape[0] / num_views 140 | 141 | # calculate reprojection matrix 142 | 143 | reproj = ReProj() 144 | reproj.set_ortho_projection() 145 | 146 | xform_per_view = [None] * num_views 147 | rotate_per_view = [None] * num_views 148 | for view_id in range(num_views): 149 | reproj.set_viewpoint(views[view_id,:]) 150 | xform_per_view[view_id] = tf.constant(np.dot(reproj.view_inv, reproj.proj_inv), dtype=tf.float32) # [4 x 4] * V 151 | rotate_per_view[view_id] = tf.constant(reproj.view_inv, dtype=tf.float32) # [4 x 4] * V 152 | 153 | # separate depth/normal by views 154 | 155 | predicts_per_view = tf.transpose(tf.reshape(predicts, [-1, num_views, shape[1], shape[2], shape[3]]), [1, 0, 2, 3, 4]) # V x n x H x W x 4 156 | depths_per_view = tf.unstack(tf.slice(predicts_per_view, [0,0,0,0,3], [-1,-1,-1,-1,1])) # [n x H x W x 1] * V 157 | normals_per_view = tf.unstack(tf.slice(predicts_per_view, [0,0,0,0,0], [-1,-1,-1,-1,3])) # [n x H x W x 3] * V 158 | 159 | # calculate projected coordinates 160 | 161 | H = shape[1] 162 | W = shape[2] 163 | vec_u = tf.constant([(u*2.0+1.0-W)/W for u in range(W)]) # W 164 | vec_v = tf.constant([(H-v*2.0-1.0)/H for v in range(H)]) # H 165 | mat_u = tf.tile(tf.reshape(vec_u, [1,1,-1,1]), (num_batches,H,1,1)) # n x H x W x 1 166 | mat_v = tf.tile(tf.reshape(vec_v, [1,-1,1,1]), (num_batches,1,W,1)) # n x H x W x 1 167 | mat_w = tf.ones([num_batches, H, W, 1]) 168 | 169 | homo_points_per_view = [tf.concat([mat_u, mat_v, mat_d, mat_w], 3) for mat_d in depths_per_view] # [n x H x W x 4] * V 170 | homo_dirs_per_view = [tf.concat([mat_n, mat_w], 3) for mat_n in normals_per_view] # [n x H x W x 4] * V 171 | 172 | # transform points 173 | 174 | points_per_view = [None] * num_views 175 | 
dirs_per_view = [None] * num_views 176 | for view_id in range(num_views): 177 | xformed = tf.matmul(tf.reshape(homo_points_per_view[view_id], [-1,4]), xform_per_view[view_id], transpose_b=True) # (n*H*W) x 4 178 | rotated = tf.matmul(tf.reshape(homo_dirs_per_view[view_id], [-1,4]), rotate_per_view[view_id], transpose_b=True) # (n*H*W) x 4 179 | points_per_view[view_id] = tf.slice(tf.reshape(xformed, [-1,H,W,4]), [0,0,0,0], [-1,-1,-1,3]) # n x H x W x 3 180 | dirs_per_view[view_id] = tf.slice(tf.reshape(rotated, [-1,H,W,4]), [0,0,0,0], [-1,-1,-1,3]) # n x H x W x 3 181 | 182 | # organize output points 183 | 184 | points = tf.transpose(tf.stack(points_per_view), [1,0,2,3,4]) # n x v x H x W x 3 185 | points = tf.reshape(points, [-1, H, W, 3]) # (n*V) x H x W x 3 186 | 187 | dirs = tf.transpose(tf.stack(dirs_per_view), [1,0,2,3,4]) # n x v x H x W x 3 188 | dirs = tf.reshape(dirs, [-1, H, W, 3]) # (n*V) x H x W x 3 189 | 190 | return points, dirs -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is part of the Sketch Modeling project. 3 | 4 | Copyright (c) 2017 5 | -Zhaoliang Lun (author of the code) / UMass-Amherst 6 | 7 | This is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This software is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this software. If not, see . 
import tensorflow as tf
import numpy as np

import tensorflow.contrib.layers as tf_layers

import layer
import image


def generateUNet(images, num_output_views, num_output_channels):
    """
    Multi-view U-Net generator: one shared encoder, and one skip-connected
    decoder (in its own 'decoder_<view>' variable scope) per output view.

    input:
        images              : n x H x W x Ci      input images ( 256 x 256 x Ci )
        num_output_views    : int                 number of output views (m)
        num_output_channels : int                 number of output image channels
    output:
        results             : (n*m) x H x W x Co  output images ( 256 x 256 x Co ), in [-1, 1] (tanh);
                                                  all m views of input 0 come first, then input 1, ...
        features            : n x D               flattened bottleneck features ( 2 x 2 x 512 = 2048 )
    """

    ###### encoding ###### (shape comments assume a 256 x 256 input)

    e1 = tf_layers.conv2d(images, num_outputs=64, kernel_size=4, stride=2, scope='e1', normalizer_fn=None)  # 128 x 128 x 64
    e2 = tf_layers.conv2d(e1, num_outputs=128, kernel_size=4, stride=2, scope='e2')  # 64 x 64 x 128
    e3 = tf_layers.conv2d(e2, num_outputs=256, kernel_size=4, stride=2, scope='e3')  # 32 x 32 x 256
    e4 = tf_layers.conv2d(e3, num_outputs=256, kernel_size=4, stride=2, scope='e4')  # 16 x 16 x 256
    e5 = tf_layers.conv2d(e4, num_outputs=256, kernel_size=4, stride=2, scope='e5')  # 8 x 8 x 256
    e6 = tf_layers.conv2d(e5, num_outputs=512, kernel_size=4, stride=2, scope='e6')  # 4 x 4 x 512
    e7 = tf_layers.conv2d(e6, num_outputs=512, kernel_size=4, stride=2, scope='e7')  # 2 x 2 x 512

    num_batches = images.get_shape()[0].value
    features = tf.reshape(e7, [num_batches, -1])  # n x 2048

    ###### decoding ######

    nc = num_output_channels
    rpv = [None] * num_output_views  # results per view
    for view in range(num_output_views):

        # every view gets its own decoder weights, but all views share the encoder above
        with tf.variable_scope('decoder_%d' % view):
            d6 = tf_layers.dropout(layer.unconv_layer(e7, num_outputs=512, kernel_size=4, stride=2, scope='d6'))  # 4 x 4 x 512
            d5 = tf_layers.dropout(layer.unconv_layer(tf.concat([d6, e6], 3), num_outputs=256, kernel_size=4, stride=2, scope='d5'))  # 8 x 8 x 256
            d4 = layer.unconv_layer(tf.concat([d5, e5], 3), num_outputs=256, kernel_size=4, stride=2, scope='d4')  # 16 x 16 x 256
            d3 = layer.unconv_layer(tf.concat([d4, e4], 3), num_outputs=256, kernel_size=4, stride=2, scope='d3')  # 32 x 32 x 256
            d2 = layer.unconv_layer(tf.concat([d3, e3], 3), num_outputs=128, kernel_size=4, stride=2, scope='d2')  # 64 x 64 x 128
            d1 = layer.unconv_layer(tf.concat([d2, e2], 3), num_outputs=64, kernel_size=4, stride=2, scope='d1')  # 128 x 128 x 64
            rpv[view] = layer.unconv_layer(tf.concat([d1, e1], 3), num_outputs=nc, kernel_size=4, stride=2, scope='re', normalizer_fn=None, activation_fn=tf.tanh)  # 256 x 256 x nc

    height = images.get_shape()[1].value
    width = images.get_shape()[2].value
    # stack gives m x n x H x W x C; transpose to n x m so the views of one input are contiguous
    results = tf.reshape(tf.transpose(tf.stack(rpv), [1, 0, 2, 3, 4]), [-1, height, width, nc])

    return results, features


def generateCNet(images, angles, num_output_channels):
    """
    Viewpoint-conditioned generator: encodes the input images, fuses the code with
    an embedding of the requested viewing angle, and decodes one image per input.
    Unlike generateUNet, the decoder has no skip connections from e1..e6.

    input:
        images              : n x H x W x Ci   input images ( 256 x 256 x Ci )
        angles              : n x 4            output viewing angle parameters
        num_output_channels : int              number of output image channels
    output:
        results             : n x H x W x Co   output images ( 256 x 256 x Co ), in [-1, 1] (tanh)
        features            : n x D            image features after the 'ifc' layer ( 2048 )
    """

    ###### encoding ###### (shape comments assume a 256 x 256 input)

    e1 = tf_layers.conv2d(images, num_outputs=64, kernel_size=4, stride=2, scope='e1', normalizer_fn=None)  # 128 x 128 x 64
    e2 = tf_layers.conv2d(e1, num_outputs=128, kernel_size=4, stride=2, scope='e2')  # 64 x 64 x 128
    e3 = tf_layers.conv2d(e2, num_outputs=256, kernel_size=4, stride=2, scope='e3')  # 32 x 32 x 256
    e4 = tf_layers.conv2d(e3, num_outputs=512, kernel_size=4, stride=2, scope='e4')  # 16 x 16 x 512
    e5 = tf_layers.conv2d(e4, num_outputs=512, kernel_size=4, stride=2, scope='e5')  # 8 x 8 x 512
    e6 = tf_layers.conv2d(e5, num_outputs=512, kernel_size=4, stride=2, scope='e6')  # 4 x 4 x 512
    e7 = tf_layers.conv2d(e6, num_outputs=512, kernel_size=4, stride=2, scope='e7')  # 2 x 2 x 512

    num_batches = images.get_shape()[0].value
    ifeat = tf.reshape(e7, [num_batches, -1])  # 2048
    ifeat = tf_layers.fully_connected(ifeat, 2048, scope='ifc')  # 2048
    features = ifeat

    # three fully-connected layers embedding the viewing angle
    vfeat = tf_layers.stack(
        angles,
        tf_layers.fully_connected,
        [64,   # 64
         64,   # 64
         64],  # 64
        scope='vfc')

    ###### decoding ######

    nc = num_output_channels
    mp = 1  # multiplier for filter size (should be something close to the square root of number of output views)

    feat = tf_layers.stack(
        tf.concat([ifeat, vfeat], 1),
        tf_layers.fully_connected,
        [1024 * mp,   # 1024*mp
         1024 * mp,   # 1024*mp
         2048 * mp],  # 2048*mp
        scope='fc')
    feat = tf.reshape(feat, [-1, 2, 2, 512 * mp])  # 2 x 2 x 512*mp

    d6 = layer.unconv_layer(feat, num_outputs=512 * mp, kernel_size=4, stride=2, scope='d6')  # 4 x 4 x 512*mp
    d5 = layer.unconv_layer(d6, num_outputs=512 * mp, kernel_size=4, stride=2, scope='d5')  # 8 x 8 x 512*mp
    d4 = layer.unconv_layer(d5, num_outputs=512 * mp, kernel_size=4, stride=2, scope='d4')  # 16 x 16 x 512*mp
    d3 = layer.unconv_layer(d4, num_outputs=256 * mp, kernel_size=4, stride=2, scope='d3')  # 32 x 32 x 256*mp
    d2 = layer.unconv_layer(d3, num_outputs=128 * mp, kernel_size=4, stride=2, scope='d2')  # 64 x 64 x 128*mp
    d1 = layer.unconv_layer(d2, num_outputs=64 * mp, kernel_size=4, stride=2, scope='d1')  # 128 x 128 x 64*mp
    results = layer.unconv_layer(d1, num_outputs=nc, kernel_size=4, stride=2, scope='re', normalizer_fn=None, activation_fn=tf.tanh)

    return results, features


def discriminate(data):
    """
    input:
        data  : n x H x W x C   data to be discriminated ( 256 x 256 x C )
    output:
        probs : n               probabilities of each sample being real
    """

    d1 = tf_layers.conv2d(data, num_outputs=64, kernel_size=4, stride=2, scope='d1', normalizer_fn=None)  # 128 x 128 x 64
    d2 = tf_layers.conv2d(d1, num_outputs=128, kernel_size=4, stride=2, scope='d2')  # 64 x 64 x 128
    d3 = tf_layers.conv2d(d2, num_outputs=256, kernel_size=4, stride=2, scope='d3')  # 32 x 32 x 256
    d4 = tf_layers.conv2d(d3, num_outputs=512, kernel_size=4, stride=2, scope='d4')  # 16 x 16 x 512
    d5 = tf_layers.conv2d(d4, num_outputs=512, kernel_size=4, stride=2, scope='d5')  # 8 x 8 x 512
    d6 = tf_layers.conv2d(d5, num_outputs=512, kernel_size=4, stride=2, scope='d6')  # 4 x 4 x 512
    d7 = tf_layers.conv2d(d6, num_outputs=512, kernel_size=4, stride=2, scope='d7')  # 2 x 2 x 512

    # Flatten the last feature map. Use its static size when known so other input
    # resolutions work too; otherwise fall back to 2048 (= 2 x 2 x 512 for 256 x 256 input).
    spatial_dims = d7.get_shape().as_list()[1:]
    flat_size = int(np.prod(spatial_dims)) if all(spatial_dims) else 2048
    feature = tf.reshape(d7, [-1, flat_size])
    probs = tf_layers.fully_connected(feature, 1, scope='fc', normalizer_fn=None, activation_fn=tf.sigmoid)  # n x 1
    probs = tf.reshape(probs, [-1])

    return probs

# --------------------------------------------------------------------------------
# /loss.py:
# --------------------------------------------------------------------------------
# This file is part of the Sketch Modeling project.
#
# Copyright (c) 2017
#   -Zhaoliang Lun (author of the code) / UMass-Amherst
#
# This is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this software. If not, see <http://www.gnu.org/licenses/>.


import tensorflow as tf
import numpy as np

import image
import reproject as rp
import view as vw


def compute_depth_loss(predicts, targets, mask, normalized=True):
    """
    Masked L-1 loss between predicted and ground-truth depth maps.

    input:
        predicts   : n x H x W x 1   predicted depths
        targets    : n x H x W x 1   ground-truth depths
        mask       : n x H x W x 1   boolean mask
        normalized : boolean         whether output loss should be normalized by pixel number
    output:
        loss       : scalar          loss value
    """

    num_batches = predicts.get_shape()[0].value
    num_channels = predicts.get_shape()[3].value

    diff = tf.abs(predicts - targets)  # L-1 loss (use tf.square for an L-2 loss)
    diff = tf.boolean_mask(diff, tf.squeeze(mask, [3]))  # keep only pixels inside the mask
    if normalized:
        # mean over the K x C masked entries times n*C == error sum / (K / n),
        # i.e. normalized by the average number of masked pixels per sample
        depth_loss = tf.reduce_mean(diff) * (num_batches * num_channels)
    else:
        depth_loss = tf.reduce_sum(diff)

    return depth_loss


def compute_normal_loss(predicts, targets, mask, normalized=True):
    """
    Masked squared-error loss between predicted and ground-truth normal maps.

    input:
        predicts   : n x H x W x 3   predicted normals
        targets    : n x H x W x 3   ground-truth normals
        mask       : n x H x W x 1   boolean mask
        normalized : boolean         whether output loss should be normalized by pixel number
    output:
        loss       : scalar          loss value
    """

    num_batches = predicts.get_shape()[0].value
    num_channels = predicts.get_shape()[3].value

    # with unit length 1-n_1*n_2 = 0.5*||n_1-n_2||^2
    diff = tf.square(predicts - targets)
    diff = tf.boolean_mask(diff, tf.squeeze(mask, [3]))
    if normalized:
        normal_loss = tf.reduce_mean(diff) * (num_batches * num_channels)
    else:
        normal_loss = tf.reduce_sum(diff)

    return normal_loss


def compute_mask_loss(predicts, targets, normalized=True):
    """
    Binary cross-entropy between predicted and ground-truth masks.

    input:
        predicts   : n x H x W x C   generated masks (-1: false, 1: true)
        targets    : n x H x W x C   ground-truth masks (-1: false, 1: true)
        normalized : boolean         whether output loss should be normalized by pixel number
    output:
        loss       : scalar          loss value
    """

    p = predicts * 0.5 + 0.5  # convert [-1, 1] to probability
    z = targets * 0.5 + 0.5
    # L = -z*log(p)-(1-z)*log(1-p); log arguments are clamped at 1e-6 to avoid log(0)
    mask_loss = tf.reduce_sum(-tf.multiply(tf.log(tf.maximum(1e-6, p)), z) - tf.multiply(tf.log(tf.maximum(1e-6, 1 - p)), 1 - z))

    if normalized:
        mask_shape = predicts.get_shape().as_list()
        num_pixels = np.prod(mask_shape[1:])  # H*W*C per sample
        mask_loss /= num_pixels

    return mask_loss


def compute_pixel_loss(predicts, targets, normalized=True):
    """
    Unmasked L-1 loss between predicted and ground-truth images.

    input:
        predicts   : n x H x W x C   predicted images
        targets    : n x H x W x C   ground-truth images
        normalized : boolean         whether output loss should be normalized by pixel number
    output:
        loss       : scalar          loss value
    """

    num_batches = predicts.get_shape()[0].value
    num_channels = predicts.get_shape()[3].value

    diff = tf.abs(predicts - targets)  # L-1 loss (use tf.square for an L-2 loss)
    if normalized:
        pixel_loss = tf.reduce_mean(diff) * (num_batches * num_channels)
    else:
        pixel_loss = tf.reduce_sum(diff)

    return pixel_loss


def compute_consist_loss(contents, normalized=True):
    """
    Consistency between predicted normals and the image-space gradient of the
    predicted depth.

    input:
        contents   : n x H x W x 4   normal/depth maps (nx, ny, nz, d)
        normalized : boolean         whether output loss should be normalized by pixel number
    output:
        loss       : scalar          loss value
    """

    # Lx = | kappa * nx + dZdx * nz |
    # Ly = | kappa * ny + dZdy * nz |

    shape = contents.get_shape().as_list()
    num_batches = shape[0]
    H = shape[1]
    W = shape[2]
    kappaX = 5.0 / H  # NOTE: view radius = 2.5
    kappaY = 5.0 / W

    # Sobel-like derivative kernels (1-4-1 smoothing)
    filter_x = tf.convert_to_tensor(np.array([
        [1.0, 0.0, -1.0],
        [4.0, 0.0, -4.0],
        [1.0, 0.0, -1.0]]), dtype=tf.float32)
    filter_y = tf.convert_to_tensor(np.array([
        [-1.0, -4.0, -1.0],
        [0.0, 0.0, 0.0],
        [1.0, 4.0, 1.0]]), dtype=tf.float32)
    # 3 x 3 -> 3 x 3 x 1 x 1 (in/out channels) as required by tf.nn.conv2d
    filter_x = tf.expand_dims(tf.expand_dims(filter_x, -1), -1)
    filter_y = tf.expand_dims(tf.expand_dims(filter_y, -1), -1)

    nx, ny, nz, d = tf.split(contents, 4, axis=3)

    dZdx = tf.nn.conv2d(d, filter=filter_x, strides=[1, 1, 1, 1], padding='SAME')
    dZdy = tf.nn.conv2d(d, filter=filter_y, strides=[1, 1, 1, 1], padding='SAME')

    Lx = tf.abs(tf.scalar_mul(kappaX, nx) + tf.multiply(dZdx, nz))
    Ly = tf.abs(tf.scalar_mul(kappaY, ny) + tf.multiply(dZdy, nz))

    if normalized:
        consist_loss = (tf.reduce_mean(Lx) + tf.reduce_mean(Ly)) * num_batches
    else:
        consist_loss = tf.reduce_sum(Lx) + tf.reduce_sum(Ly)

    return consist_loss


def compute_corres_geom_loss(predicts, corres, views):
    """
    Multi-view geometric consistency of reprojected points and directions.

    input:
        predicts : (n*v) x H x W x 4   predicted images
        corres   : n x G x M x v       correspondence point indices (G groups of M correspondences across v span views)
        views    : vw.Views            view points data
    output:
        loss     : scalar              loss value (0 when views has no edges)
    """

    if views.num_edges == 0:
        return 0

    position_factor = 1.0
    direction_factor = 1.0

    shape = predicts.get_shape().as_list()
    H = shape[1]
    W = shape[2]
    # integer division: num_batches sizes a Python list and drives range() below
    num_batches = shape[0] // views.num_views
    num_samples = corres.get_shape()[2].value

    points, dirs = rp.transform_tensor(predicts, views.views)  # (n*V) x H x W x 3

    # tf.unpack/tf.pack were renamed tf.unstack/tf.stack in TensorFlow 1.0
    batch_points = tf.unstack(tf.reshape(points, [-1, views.num_views, H, W, 3]))  # [V x H x W x 3] * n
    batch_dirs = tf.unstack(tf.reshape(dirs, [-1, views.num_views, H, W, 3]))  # [V x H x W x 3] * n
    batch_corres = tf.unstack(corres)  # [G x M x v] * n

    batch_losses = [None] * num_batches
    for batch_id in range(num_batches):
        all_points = tf.reshape(batch_points[batch_id], [-1, 3])  # (V*H*W) x 3
        all_dirs = tf.reshape(batch_dirs[batch_id], [-1, 3])  # (V*H*W) x 3
        all_corres = tf.reshape(batch_corres[batch_id], [-1])  # (G*M*v)
        slice_points = tf.reshape(tf.gather(all_points, all_corres), [views.num_edges, -1, views.edge_size, 3])  # G x M x v x 3
        slice_dirs = tf.reshape(tf.gather(all_dirs, all_corres), [views.num_edges, -1, views.edge_size, 3])  # G x M x v x 3

        # compute position loss as variance of reprojected point positions across nearby views
        normalized_points = slice_points - tf.tile(tf.reduce_mean(slice_points, reduction_indices=2, keep_dims=True), [1, 1, views.edge_size, 1])  # G x M x v x 3
        position_loss = tf.reduce_mean(tf.multiply(normalized_points, normalized_points)) * 3.0

        # compute direction loss as mean(1-dot(n,n)) for all pairs of reprojected directions across nearby views
        lensq_dirs = tf.maximum(tf.reduce_sum(tf.multiply(slice_dirs, slice_dirs), reduction_indices=3, keep_dims=True), 1e-3)
        normalized_dirs = tf.multiply(slice_dirs, tf.tile(tf.rsqrt(lensq_dirs), (1, 1, 1, 3)))
        transposed = tf.reshape(tf.transpose(normalized_dirs, [2, 0, 1, 3]), [views.edge_size, -1])  # v x (G*M*3)
        # each entry of the v x v Gram matrix sums G*M dot products, hence the 1/(G*M) rescale
        direction_loss = 1.0 - tf.reduce_mean(tf.matmul(transposed, transposed, transpose_b=True)) * (1.0 / (views.num_edges * num_samples))

        batch_losses[batch_id] = position_factor * position_loss + direction_factor * direction_loss

    loss = tf.reduce_sum(tf.stack(batch_losses))

    return loss


def compute_corres_mask_loss(predicts, corres, views):
    """
    Multi-view consistency of predicted mask probabilities at corresponding points.

    input:
        predicts : (n*v) x H x W x 1   predicted masks
        corres   : n x G x M x v       correspondence point indices (G groups of M correspondences across v span views)
        views    : vw.Views            view points data
    output:
        loss     : scalar              loss value (0 when views has no edges)
    """

    if views.num_edges == 0:
        return 0

    shape = predicts.get_shape().as_list()
    H = shape[1]
    W = shape[2]
    # integer division: num_batches sizes a Python list and drives range() below
    num_batches = shape[0] // views.num_views

    probs = predicts * 0.5 + 0.5  # [-1,1] => [0,1]

    batch_probs = tf.unstack(tf.reshape(probs, [-1, views.num_views, H, W, 1]))  # [V x H x W x 1] * n
    batch_corres = tf.unstack(corres)  # [G x M x v] * n

    batch_losses = [None] * num_batches
    for batch_id in range(num_batches):
        all_probs = tf.reshape(batch_probs[batch_id], [-1, 1])  # (V*H*W) x 1
        all_corres = tf.reshape(batch_corres[batch_id], [-1])  # (G*M*v)
        slice_probs = tf.reshape(tf.gather(all_probs, all_corres), [views.num_edges, -1, views.edge_size, 1])  # G x M x v x 1

        # compute mask loss as Jensen-Shannon divergence of predicted mask probabilities across nearby views:
        # H(mean_v p) - mean_v H(p), reduced over axis 2 (the v views of each correspondence)
        mask_loss = tf.reduce_mean(compute_entropy(tf.reduce_mean(slice_probs, reduction_indices=2)) - tf.reduce_mean(compute_entropy(slice_probs), reduction_indices=2))

        batch_losses[batch_id] = mask_loss

    loss = tf.reduce_sum(tf.stack(batch_losses))

    return loss


def compute_entropy(tensor):
    """
    Element-wise binary entropy of probabilities (1e-6 added inside the logs).

    input:
        tensor  : any shape tensor
    output:
        entropy : tensor having the same shape with input tensor
    """

    entropy = - tf.multiply(tensor, tf.log(tensor + 1e-6)) - tf.multiply(1.0 - tensor, tf.log(1.0 - tensor + 1e-6))
    return entropy

# --------------------------------------------------------------------------------
# /data.py:
# --------------------------------------------------------------------------------
# This file is part of the Sketch Modeling project.
#
# Copyright (c) 2017
#   -Zhaoliang Lun (author of the code) / UMass-Amherst
#
# This is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this software. If not, see <http://www.gnu.org/licenses/>.
import tensorflow as tf
import numpy as np

import os
import math

import image
import view as vw
from pathlib import Path

NUM_CORRESPONDENCES = 1024


def _read_shape_list(config, list_name):
    """Return the lines of <config.data_dir>/<list_name> as a list of shape names."""
    with open(os.path.join(config.data_dir, list_name), 'r') as shape_list_file:
        return shape_list_file.read().splitlines()


def _ensure_parent_dir(file_name):
    """Create the directory containing file_name if it does not exist yet."""
    path = os.path.dirname(file_name)
    if path:  # a bare file name has no directory part; os.makedirs('') would raise
        os.makedirs(path, exist_ok=True)


def load_data(config, views, shape_list, shuffle=True, batch_size=-1):
    """
    Build the TF1 queue-based input pipeline for a list of shapes.

    input:
        config       tf.app.flags     command line arguments
        views        vw.View          view points information
        shape_list   list of string   input shape name list
        shuffle      bool             whether input shape list should be shuffled
        batch_size   int              batch size; -1 means config.batch_size
    output:
        name_batch   n x string            shape names
        source_batch n x H x W x Ci        source images
        target_batch (n*m) x H x W x Co    target images in m views
        mask_batch   (n*m) x H x W x 1     target boolean masks in m views
        angle_batch  (n*m) x 4             target viewing angle params in m views
        num_shapes   int                   number of loaded shapes
    """

    if batch_size == -1:
        batch_size = config.batch_size

    # handle affix

    num_dnfs_views = max(2, len(config.sketch_views))
    dnfs_prefix_list = ['dnfs/' for _ in range(num_dnfs_views)]
    dnfs_interfix_list = ['/dnfs-%d' % config.image_size for _ in range(num_dnfs_views)]
    dnfs_suffix_list = ['-%d.png' % view for view in range(num_dnfs_views)]

    num_dn_views = 12
    dn_prefix_list = ['dn/' for _ in range(num_dn_views)]
    dn_interfix_list = ['/dn-%d' % config.image_size for _ in range(num_dn_views)]
    dn_suffix_list = ['-%d.png' % view for view in range(num_dn_views)]

    target_prefix_list = dnfs_prefix_list + dn_prefix_list
    target_interfix_list = dnfs_interfix_list + dn_interfix_list
    target_suffix_list = dnfs_suffix_list + dn_suffix_list
    num_target_views = views.num_views

    # build input queue

    if config.continuous_view and config.test:
        # each shape name is repeated once per target view
        shape_list_queue = tf.train.input_producer([name for name in shape_list for _ in range(num_target_views)], shuffle=False)
    else:
        shape_list_queue = tf.train.input_producer(shape_list, shuffle=shuffle)

    # load data from queue

    shape_name = shape_list_queue.dequeue()
    extension = 'jpg'
    image_dir = config.sketch_dir + shape_name + config.sketch_set
    if 'human' in config.sketch_set:
        file_glob = image_dir + '/' + '*.' + extension
        source_files_list = tf.matching_files(file_glob)
    else:
        file_glob_base = image_dir + '/base/' + '*.' + extension
        file_glob_bias = image_dir + '/bias/' + '*.' + extension
        source_files_list_base = tf.matching_files(file_glob_base)
        source_files_list_bias = tf.matching_files(file_glob_bias)
        source_files_list = tf.concat([source_files_list_base, source_files_list_bias], 0)

    # one sketch file is drawn at random from the matched files
    source_file_queue = tf.train.string_input_producer(source_files_list, shuffle=True)
    source_file = source_file_queue.dequeue()
    if not config.continuous_view:
        target_files = [config.data_dir + target_prefix_list[view] + shape_name + target_interfix_list[view] + target_suffix_list[view] for view in range(num_target_views)]
        target_angles = tf.zeros([num_target_views, 4])
    else:
        angle_list = [vw.view2angle(view) for view in views.views]
        view_list_queue = tf.train.slice_input_producer([angle_list, target_prefix_list, target_interfix_list, target_suffix_list], shuffle=(not config.test))
        target_files = [config.data_dir + view_list_queue[1] + shape_name + view_list_queue[2] + view_list_queue[3]]  # only one single image
        target_angles = [view_list_queue[0]]

    # decode source images
    # resize straight to config.image_size so the slice below really is a no-op
    source_images = [tf.image.resize_images(tf.image.decode_png(tf.read_file(source_file), channels=1, dtype=tf.uint8), [config.image_size, config.image_size], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)]
    source_image = tf.concat(source_images, 2)  # put multi-view images into different channels
    source_image = image.normalize_image(tf.slice(source_image, [0, 0, 0], [config.image_size, config.image_size, -1]))  # just do a useless slicing to establish size
    source_image = tf.concat([source_image, tf.image.flip_left_right(source_image)], 2)  # HACK: add horizontally flipped image as input

    # decode target images

    if not config.test:
        target_images = tf.stack([tf.image.decode_png(tf.read_file(file), channels=4, dtype=tf.uint16) for file in target_files])
        target_images = image.normalize_image(tf.slice(target_images, [0, 0, 0, 0], [-1, config.image_size, config.image_size, -1]))
    else:
        target_images = tf.ones([len(target_files), config.image_size, config.image_size, 4])  # dummy target for testing
    target_masks = image.extract_boolean_mask(target_images)

    if config.predict_normal:
        # pre-process normal background
        target_shape = target_images.get_shape().as_list()
        target_background = tf.concat([tf.zeros(target_shape[:-1] + [2]), tf.ones(target_shape[:-1] + [2])], 3)  # (0,0,1,1)
        target_images = tf.where(tf.tile(target_masks, [1, 1, 1, target_shape[3]]), target_images, target_background)
    else:
        # retain depth only
        target_images = tf.slice(target_images, [0, 0, 0, 3], [-1, -1, -1, 1])

    # append the mask as the last channel
    target_images = tf.concat([target_images, image.convert_to_real_mask(target_masks)], 3)

    # create prefetching tensor

    num_shapes = len(shape_list)
    min_queue_examples = max(1, int(num_shapes * 0.01))

    tensor_data = [shape_name, source_image, target_images, target_masks, target_angles]
    print('name: ', shape_name)
    print('source: ', source_image)
    print('target: ', target_images)
    print('mask: ', target_masks)
    print('angle: ', target_angles)

    if shuffle:
        num_preprocess_threads = 12
        batch_data = tf.train.shuffle_batch(
            tensor_data,
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_size,
            min_after_dequeue=min_queue_examples)
    else:
        num_preprocess_threads = 1
        batch_data = tf.train.batch(
            tensor_data,
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples)

    # merge the n (batch) and m (view) axes of the per-view tensors
    name_batch = batch_data[0]
    source_batch = batch_data[1]
    target_batch = batch_data[2]
    target_batch = tf.reshape(target_batch, [-1] + target_batch.get_shape().as_list()[2:])
    mask_batch = batch_data[3]
    mask_batch = tf.reshape(mask_batch, [-1] + mask_batch.get_shape().as_list()[2:])
    angle_batch = batch_data[4]
    angle_batch = tf.reshape(angle_batch, [-1] + angle_batch.get_shape().as_list()[2:])

    print('*******************************')
    print('name: ', name_batch)
    print('source: ', source_batch)
    print('target: ', target_batch)
    print('mask: ', mask_batch)
    print('angle: ', angle_batch)

    return name_batch, source_batch, target_batch, mask_batch, angle_batch, num_shapes


def load_train_data(config, views, batch_size=-1):
    """Load the shapes listed in <config.data_dir>/train-list.txt (shuffled)."""

    print("Loading training data...")

    shape_list = _read_shape_list(config, 'train-list.txt')

    return load_data(config, views, shape_list, shuffle=True, batch_size=batch_size)


def load_test_data(config, views, batch_size=-1):
    """
    Load the shapes listed in <config.data_dir>/test-list.txt, skipping every shape
    that already has a result folder under <config.check_dir>/results/03001627/ or
    <config.test_dir>/results/03001627/ (presumably so an interrupted test run can
    resume). A missing results folder simply means nothing is skipped for it.
    """

    print("Loading testing data...")

    shape_list = _read_shape_list(config, 'test-list.txt')

    finished = set()
    for root in (config.check_dir, config.test_dir):
        result_path = Path(root) / 'results' / '03001627'
        if not result_path.exists():
            continue
        # entries look like '03001627/<shape id>', matching the test-list format
        finished.update(item.parts[-2] + '/' + item.name for item in result_path.iterdir() if item.is_dir())
    shape_list = [shape for shape in shape_list if shape not in finished]

    return load_data(config, views, shape_list, shuffle=False, batch_size=batch_size)


def load_encode_data(config, views, batch_size=-1):
    """Load the shapes listed in <config.data_dir>/list.txt (not shuffled)."""

    print("Loading encoding data...")

    shape_list = _read_shape_list(config, 'list.txt')

    return load_data(config, views, shape_list, shuffle=False, batch_size=batch_size)


def load_validate_data(config, views, batch_size=-1):
    """Load the shapes listed in <config.data_dir>/validate-list.txt (not shuffled)."""

    print("Loading validation data...")

    shape_list = _read_shape_list(config, 'validate-list.txt')

    return load_data(config, views, shape_list, shuffle=False, batch_size=batch_size)


def write_bin_data(file_name, data):
    """
    Write numpy array data to file_name as raw bytes (ndarray.tofile: the
    array's own dtype, C order, no header), creating the parent directory if needed.
    """

    _ensure_parent_dir(file_name)
    data.tofile(file_name)


def write_pfm_data(file_name, data):
    """
    Write a float image to file_name in PFM (Portable Float Map) format.

    input:
        file_name  string              output path; parent directories are created
        data       H x W x C array     C must be 1 ('Pf') or 3 ('PF');
                                       an H x W array is written as a single channel
    raises:
        ValueError  if the number of channels is not 1 or 3
    """

    data = np.asarray(data)
    if data.ndim == 2:
        data = data[:, :, np.newaxis]
    if data.ndim != 3 or data.shape[2] not in (1, 3):
        raise ValueError('incorrect number of channels')
    magic = 'Pf' if data.shape[2] == 1 else 'PF'

    _ensure_parent_dir(file_name)

    # PFM header is ASCII text; the file is binary, so it must be written as bytes.
    header = '%s\n%d %d\n-1.0\n' % (magic, data.shape[1], data.shape[0])
    # The negative scale (-1.0) declares little-endian 32-bit floats, so convert to
    # exactly that; PFM also stores pixels from bottom to top.
    pixels = np.flipud(data).astype('<f4')
    with open(file_name, 'wb') as file:
        file.write(header.encode('ascii'))
        file.write(pixels.tobytes())

# --------------------------------------------------------------------------------
# /monnet.py:
# --------------------------------------------------------------------------------
# This file is part of the Sketch Modeling project.
#
# Copyright (c) 2017
#   -Zhaoliang Lun (author of the code) / UMass-Amherst
#
# This is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this software. If not, see <http://www.gnu.org/licenses/>.
import tensorflow as tf
import numpy as np

import tensorflow.contrib.framework as tf_framework

import time
import os
import math

import data
import image
import network
import layer
import loss
import reproject as rp
import view as vw


class MonNet(object):
    """Sketch-to-multi-view network: a generator plus an optional discriminator.

    Written against the TensorFlow 1.x graph API. `build_network` adds the ops
    for one pass (training, validation, testing or encoding) to the default
    graph and stores the tensors/ops it needs on ``self``. `train`, `test` and
    `encode` then drive those ops with a session fed by queue runners.
    """

    def __init__(self, config):
        """Store the run configuration.

        Args:
            config: options object. Attributes read by this class include
                continuous_view, predict_normal, no_adversarial, sketch_views,
                batch_size, max_epochs, train_dir, test_dir and encode_dir.
        """
        self.config = config

    @staticmethod
    def _parse_checkpoint_step(checkpoint_path):
        """Return the integer after the last '-' of the checkpoint file name, or 0.

        train() saves checkpoints as 'model.ckpt' with global_step, so the step
        is the numeric suffix. A name without a numeric suffix yields 0.
        """
        suffix = os.path.basename(checkpoint_path).split('-')[-1]
        try:
            return int(suffix)
        except ValueError:
            return 0

    def build_network(self, names, sources, targets, masks, angles, views, is_training=False, is_validation=False, is_testing=False, is_encoding=False):
        """
        Build the graph for one pass and store the resulting tensors/ops on self.

        input:
            names         : n x String                 shape names
            sources       : n x H x W x C              source images
            targets       : (n*m) x H x W x C          target images in m views (ground-truth)
            masks         : (n*m) x H x W x 1          target boolean masks in m views (ground-truth)
            angles        : (n*m) x 4                  viewing angle parameters (m=1 for continuous view prediction)
            views         : vw.Views                   view points information
            is_training   : boolean                    whether it is in training routine
            is_validation : boolean                    whether it is handling validation data set
            is_testing    : boolean                    whether it is in testing routine
            is_encoding   : boolean                    whether it is encoding input

        Attributes set on self:
            is_encoding   : encode_names, encode_features (then returns early).
            is_validation : names, valid_losses, valid_images,
                            valid_summary_losses, valid_summary_op (then returns early).
            otherwise     : names, train_losses_G/D, grad_G_list, grad_G_placeholder,
                            update_G_op (plus the D counterparts unless
                            config.no_adversarial), train_summary_G_op/D_op,
                            num_params_G/D, pngs, errors, results, bn_G_op/bn_D_op.

        Note: is_training and is_testing are not read by this method.
        """

        print('Building network...')

        source_size = sources.get_shape().as_list()
        if self.config.continuous_view:
            num_output_views = 1
        else:
            num_output_views = views.num_views

        # scope names

        var_scope_G = 'G_net'
        var_scope_D = 'D_net'
        bn_scope_G = 'G_bn'
        bn_scope_D = 'D_bn'
        train_summary_G_name = 'train_summary_G'
        train_summary_D_name = 'train_summary_D'
        valid_summary_name = 'valid_summary'

        # generator: UNet over all discrete views, or an angle-conditioned CNet
        # when predicting a continuous view

        num_channels = targets.get_shape()[3].value
        with tf.variable_scope(var_scope_G):
            if not self.config.continuous_view:
                with tf_framework.arg_scope(layer.unet_scopes(bn_scope_G)):
                    preds, features = network.generateUNet(sources, num_output_views, num_channels)  # (n*m) x H x W x C ; n x D
            else:
                with tf_framework.arg_scope(layer.cnet_scopes(bn_scope_G)):
                    preds, features = network.generateCNet(sources, angles, num_channels)  # n x H x W x C ; n x D

        if is_encoding:
            self.encode_names = names
            self.encode_features = features
            return  # everything below is irrelevant to the encoding pass

        # extract prediction contents: the last channel is the mask, the
        # remaining channels are the content (normal + depth, or depth only)

        preds_content = tf.slice(preds, [0, 0, 0, 0], [-1, -1, -1, num_channels - 1])
        preds_mask = tf.slice(preds, [0, 0, 0, num_channels - 1], [-1, -1, -1, 1])
        preds = image.apply_mask(preds_content, preds_mask)
        targets_content = tf.slice(targets, [0, 0, 0, 0], [-1, -1, -1, num_channels - 1])
        targets_mask = tf.slice(targets, [0, 0, 0, num_channels - 1], [-1, -1, -1, 1])
        targets = image.apply_mask(targets_content, targets_mask)
        if self.config.predict_normal:
            preds_normal = tf.slice(preds_content, [0, 0, 0, 0], [-1, -1, -1, 3])
            preds_depth = tf.slice(preds_content, [0, 0, 0, 3], [-1, -1, -1, 1])
            targets_normal = tf.slice(targets_content, [0, 0, 0, 0], [-1, -1, -1, 3])
            targets_depth = tf.slice(targets_content, [0, 0, 0, 3], [-1, -1, -1, 1])
        else:
            # no normal prediction: all-zero 3-channel stand-ins keep later code uniform
            preds_depth = preds_content
            preds_normal = tf.tile(tf.zeros_like(preds_depth), [1, 1, 1, 3])
            targets_depth = targets_content
            targets_normal = tf.tile(tf.zeros_like(targets_depth), [1, 1, 1, 3])

        # expand tensors: one copy of each source image / name per output view

        sources_expanded = tf.reshape(
            tf.tile(sources, [1, num_output_views, 1, 1]),
            [-1, source_size[1], source_size[2], source_size[3]])  # (n*m) x H x W x C

        names_expanded = tf.reshape(tf.tile(tf.expand_dims(names, 1), [1, num_output_views]), [-1])
        names_suffix = ["--%d" % view for _ in range(source_size[0]) for view in range(num_output_views)]
        names_expanded = tf.reduce_join([names_expanded, names_suffix], 0)
        self.names = names_expanded

        # discriminator

        if not self.config.no_adversarial:
            with tf.variable_scope(var_scope_D):
                with tf_framework.arg_scope(layer.unet_scopes(bn_scope_D)):
                    disc_data = tf.concat([targets, preds], 0)
                    disc_data = tf.concat([tf.concat([sources_expanded, sources_expanded], 0), disc_data], 3)  # HACK: insert input data for discrimination in UNet
                    probs = network.discriminate(disc_data)  # (n*m*2)

        # losses

        # NOTE: learning hyper-parameters
        lambda_p = 1.0  # image loss
        lambda_a = 0.01  # adversarial loss

        dl = loss.compute_depth_loss(preds_depth, targets_depth, masks)
        nl = loss.compute_normal_loss(preds_normal, targets_normal, masks)
        ml = loss.compute_mask_loss(preds_mask, targets_mask)
        loss_g_p = dl + nl + ml

        if self.config.no_adversarial:
            loss_g_a = 0.0
            loss_d_r = 0.0
            loss_d_f = 0.0
        else:
            # first half of probs scores the targets, second half the predictions
            # (same order as the concat above); clamp before log to avoid -inf
            probs_targets, probs_preds = tf.split(probs, 2, axis=0)  # (n*m)
            loss_g_a = tf.reduce_sum(-tf.log(tf.maximum(probs_preds, 1e-6)))
            loss_d_r = tf.reduce_sum(-tf.log(tf.maximum(probs_targets, 1e-6)))
            loss_d_f = tf.reduce_sum(-tf.log(tf.maximum(1.0 - probs_preds, 1e-6)))

        loss_G = loss_g_p * lambda_p + loss_g_a * lambda_a
        loss_D = loss_d_r + loss_d_f

        if is_validation:
            self.valid_losses = tf.stack([loss_G, loss_g_p, loss_g_a, loss_D, loss_d_r, loss_d_f])
            self.valid_images = tf.stack([
                image.encode_raw_batch_images(preds),
                image.encode_raw_batch_images(targets),
                image.encode_raw_batch_images(preds_normal),
                image.encode_raw_batch_images(preds_depth),
                image.encode_raw_batch_images(preds_mask)])
            # validate_loss() averages the losses in Python and feeds them back
            # through this placeholder to produce the summaries
            self.valid_summary_losses = tf.placeholder(tf.float32, shape=self.valid_losses.get_shape())
            valid_tags = ('vG_all', 'vG_p', 'vG_a', 'vD_all', 'vD_r', 'vD_f')
            for tag, value in zip(valid_tags, tf.unstack(self.valid_summary_losses)):
                tf.summary.scalar(tag, value, collections=[valid_summary_name])
            self.valid_summary_op = tf.summary.merge_all(valid_summary_name)
            return  # everything below is irrelevant to the validation pass

        self.train_losses_G = tf.stack([loss_G, loss_g_p, loss_g_a])
        self.train_losses_D = tf.stack([loss_D, loss_d_r, loss_d_f])
        for tag, value in (('G_all', loss_G), ('G_p', loss_g_p), ('G_a', loss_g_a)):
            tf.summary.scalar(tag, value, collections=[train_summary_G_name])
        for tag, value in (('D_all', loss_D), ('D_r', loss_d_r), ('D_f', loss_d_f)):
            tf.summary.scalar(tag, value, collections=[train_summary_D_name])

        # statistics on variables (variables are split by scope name)

        all_vars = tf.trainable_variables()
        all_vars_G = [var for var in all_vars if var_scope_G in var.name]
        all_vars_D = [var for var in all_vars if var_scope_D in var.name]
        self.num_params_G = sum(np.prod(var.get_shape().as_list()) for var in all_vars_G)
        self.num_params_D = sum(np.prod(var.get_shape().as_list()) for var in all_vars_D)

        # optimization

        # NOTE: learning hyper-parameters
        init_learning_rate = 0.0001
        adam_beta1 = 0.9
        adam_beta2 = 0.999
        opt_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(init_learning_rate, global_step=opt_step, decay_steps=10000, decay_rate=0.96, staircase=True)

        opt_G = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=adam_beta1, beta2=adam_beta2, name='ADAM_G')
        opt_D = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=adam_beta1, beta2=adam_beta2, name='ADAM_D')

        # Gradients are fetched separately (grad_*_list), accumulated in Python
        # by train(), and applied by feeding them into these placeholders.
        grad_G = opt_G.compute_gradients(loss_G, var_list=all_vars_G, colocate_gradients_with_ops=True)
        grad_G = [(grad, var) for grad, var in grad_G if grad is not None]
        self.grad_G_placeholder = [(tf.placeholder(tf.float32, shape=var.get_shape()), var) for _, var in grad_G]
        self.grad_G_list = [grad for grad, _ in grad_G]
        self.update_G_op = opt_G.apply_gradients(self.grad_G_placeholder, global_step=opt_step)  # only update opt_step in G net

        if not self.config.no_adversarial:
            grad_D = opt_D.compute_gradients(loss_D, var_list=all_vars_D, colocate_gradients_with_ops=True)
            grad_D = [(grad, var) for grad, var in grad_D if grad is not None]
            self.grad_D_placeholder = [(tf.placeholder(tf.float32, shape=var.get_shape()), var) for _, var in grad_D]
            self.grad_D_list = [grad for grad, _ in grad_D]
            self.update_D_op = opt_D.apply_gradients(self.grad_D_placeholder)

        # visualization stuffs: only the first half of the channels of
        # sources_expanded is visualized; the second half is unused here

        num_sketch_views = len(self.config.sketch_views)
        sources_original, _ = tf.split(sources_expanded, 2, axis=3)
        if num_sketch_views == 1:  # single input
            sources_front = sources_original
            sources_side = tf.ones_like(sources_front)  # fake side sketch
            sources_top = tf.ones_like(sources_front)  # fake top sketch
        elif num_sketch_views == 2:  # double input
            sources_front, sources_side = tf.split(sources_original, 2, axis=3)
            sources_top = tf.ones_like(sources_front)  # fake top sketch
        elif num_sketch_views == 3:  # triple input
            sources_front, sources_side, sources_top = tf.split(sources_original, 3, axis=3)
        if sources_front.get_shape()[3].value == 1 and targets.get_shape()[3].value == 4:
            # 1-channel sketches vs 4-channel targets: expand each sketch to
            # 3 channels plus an all-ones channel so the rows can be tiled together
            alpha_front = tf.ones_like(sources_front)
            alpha_side = tf.ones_like(sources_side)
            alpha_top = tf.ones_like(sources_top)
            rgb_front = image.convert_to_rgb(sources_front, channels=3)
            rgb_side = image.convert_to_rgb(sources_side, channels=3)
            rgb_top = image.convert_to_rgb(sources_top, channels=3)
            sources_front = tf.concat([rgb_front, alpha_front], 3)
            sources_side = tf.concat([rgb_side, alpha_side], 3)
            sources_top = tf.concat([rgb_top, alpha_top], 3)

        input_row = tf.concat([sources_front, sources_side], 2)
        output_row = tf.concat([targets, preds], 2)

        result_tile = tf.concat([input_row, output_row], 1)
        result_tile = image.saturate_image(image.unnormalize_image(result_tile))

        tf.summary.image('result', result_tile, 12, [train_summary_G_name])

        self.train_summary_G_op = tf.summary.merge_all(train_summary_G_name)
        self.train_summary_D_op = tf.summary.merge_all(train_summary_D_name)

        # output images (converted to uint16 before encoding)

        if num_sketch_views == 1:
            all_input_row = sources_front
        elif num_sketch_views == 2:
            all_input_row = tf.concat([sources_front, sources_side], 2)
        elif num_sketch_views == 3:
            all_input_row = tf.concat([sources_front, sources_side, sources_top], 2)

        def to_uint16(tensor):
            # unnormalize to [0, 65535] and saturate-cast to uint16
            return image.saturate_image(image.unnormalize_image(tensor, maxval=65535.0), dtype=tf.uint16)

        png_input = image.encode_batch_images(to_uint16(all_input_row))
        png_gt = image.encode_batch_images(to_uint16(targets))
        png_output = image.encode_batch_images(to_uint16(preds))
        png_normal = image.encode_batch_images(to_uint16(preds_normal))
        png_depth = image.encode_batch_images(to_uint16(preds_depth))
        png_mask = image.encode_batch_images(to_uint16(preds_mask))
        self.pngs = tf.stack([png_input, png_gt, png_output, png_normal, png_depth, png_mask])

        # output results

        pixel_shape = preds.get_shape().as_list()
        num_pixels = np.prod(pixel_shape[1:])
        self.errors = tf.reduce_sum(tf.abs(preds - targets), [1, 2, 3]) / num_pixels  # mean absolute error per image (quick check)
        self.results = preds

        # batch normalization: group the ops registered in each BN collection

        self.bn_G_op = tf.group(*tf.get_collection(bn_scope_G))
        self.bn_D_op = tf.group(*tf.get_collection(bn_scope_D))

    def train(self, sess, views, num_train_shapes, num_valid_shapes):
        """Run the training loop until config.max_epochs is reached.

        Every step fetches G gradients and, while the discriminator is being
        trained, D gradients. These are accumulated in Python and applied every
        update_interval steps through the placeholders built by build_network.
        At each update, D training is enabled only while D's averaged total loss
        exceeds 0.1 times G's averaged adversarial loss.

        Args:
            sess: session on the graph built by build_network.
            views: view point information (not read here).
            num_train_shapes: number of training shapes, used to compute epochs.
            num_valid_shapes: number of validation shapes per validation run.
        """

        print('Training...')

        ckpt = tf.train.get_checkpoint_state(self.config.train_dir)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=10.0, max_to_keep=2)
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            self.step = self._parse_checkpoint_step(ckpt.model_checkpoint_path)
        else:
            self.saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=10.0, max_to_keep=2)
            self.step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        self.summarizer = tf.summary.FileWriter(self.config.train_dir, sess.graph)

        # Intervals are in steps. max(1, ...) keeps them valid when batch_size > 40;
        # a zero interval would raise ZeroDivisionError in the modulo tests below.
        print_interval = max(1, 40 // self.config.batch_size)
        update_interval = max(1, 40 // self.config.batch_size)
        summary_interval = 200
        validate_interval = 200
        output_interval = 1000
        checkpoint_interval = 1000

        print('Start iterating...')

        start_time = time.time()

        train_D_net = not self.config.no_adversarial
        batch_grad_G_list = None
        batch_grad_D_list = None
        batch_losses_G = None
        batch_losses_D = None
        step_losses_G = None
        step_losses_D = None

        while True:

            # compute epochs and which periodic actions fire on this step

            next_step = self.step + 1
            epochs = 1.0 * next_step * self.config.batch_size / num_train_shapes
            do_print = (next_step % print_interval == 0)
            do_update = (next_step % update_interval == 0)
            do_validate = (next_step % validate_interval == 0)
            do_summary = (next_step % summary_interval == 0)
            do_checkpoint = (next_step % checkpoint_interval == 0)
            do_output = (next_step % output_interval == 0)

            # training networks: fetch gradients and losses, and run the BN ops

            step_G_list = sess.run(self.grad_G_list + [self.bn_G_op, self.train_losses_G])
            step_grad_G_list = step_G_list[:-2]
            step_losses_G = step_G_list[-1] / self.config.batch_size
            batch_grad_G_list = self.cumulate_gradients(batch_grad_G_list, step_grad_G_list)

            if train_D_net:
                step_D_list = sess.run(self.grad_D_list + [self.bn_D_op, self.train_losses_D])
                step_grad_D_list = step_D_list[:-2]
                step_losses_D = step_D_list[-1] / self.config.batch_size
                batch_grad_D_list = self.cumulate_gradients(batch_grad_D_list, step_grad_D_list)
            elif step_losses_D is None:
                # numpy array (not a list) so that '+' below adds element-wise
                step_losses_D = np.zeros(3)

            batch_losses_G = step_losses_G if batch_losses_G is None else batch_losses_G + step_losses_G
            batch_losses_D = step_losses_D if batch_losses_D is None else batch_losses_D + step_losses_D

            # update gradients: apply the gradients averaged over update_interval steps

            if do_update:
                grad_G_feed = {placeholder: grad / update_interval
                               for (placeholder, _), grad in zip(self.grad_G_placeholder, batch_grad_G_list)}
                sess.run(self.update_G_op, feed_dict=grad_G_feed)
                batch_grad_G_list = None

                if train_D_net:
                    grad_D_feed = {placeholder: grad / update_interval
                                   for (placeholder, _), grad in zip(self.grad_D_placeholder, batch_grad_D_list)}
                    sess.run(self.update_D_op, feed_dict=grad_D_feed)
                    batch_grad_D_list = None

                if not self.config.no_adversarial:
                    batch_losses_G = batch_losses_G / update_interval
                    if batch_losses_D is not None:
                        batch_losses_D = batch_losses_D / update_interval
                    # NOTE: subscripts -- [0] of D losses is loss_D, [2] of G losses is loss_g_a
                    train_D_net = (batch_losses_D[0] > batch_losses_G[2] * 0.1)
                batch_losses_G = None
                batch_losses_D = None

            # validation

            if do_validate:
                self.validate_loss(sess, num_valid_shapes)

            if do_output:
                self.validate_output(sess, num_valid_shapes, epochs)

            # log

            if do_summary:
                summary_G_str = sess.run(self.train_summary_G_op)
                self.summarizer.add_summary(summary_G_str, self.step)
                if train_D_net:
                    summary_D_str = sess.run(self.train_summary_D_op)
                    self.summarizer.add_summary(summary_D_str, self.step)

            if do_checkpoint:
                self.saver.save(sess, os.path.join(self.config.train_dir, 'model.ckpt'), global_step=next_step)

            if do_print:
                now_time = time.time()
                batch_duration = now_time - start_time
                start_time = now_time
                log_str_1 = 'Step %7d: %5.1f sec, epoch: %7.2f, ' % (next_step, batch_duration, epochs)
                log_str_2 = 'losses: %7.3g, %7.3g, %7.3g, %7.3g, %7.3g, %7.3g;' % \
                    (step_losses_G[0], step_losses_G[1], step_losses_G[2],
                     step_losses_D[0], step_losses_D[1], step_losses_D[2])
                print(log_str_1, end='')
                print(log_str_2)
                log_file_name = os.path.join(self.config.train_dir, 'log.txt')
                with open(log_file_name, 'a') as log_file:
                    log_file.write(log_str_1 + log_str_2 + '\n')

            if epochs >= self.config.max_epochs:
                break

            self.step += 1

        coord.request_stop()
        coord.join(threads)

    def test(self, sess, views, num_shapes):
        """Restore the latest checkpoint and export predictions for the test set.

        For every processed view this method does three things:
        - writes input/gt/pred/normal/depth/mask images under
          <test_dir>/images/<shape>/;
        - writes the prediction under <test_dir>/results/<shape>/;
        - appends the view's error to <test_dir>/log.txt, starting a new line
          before each shape's view '0'.
        It stops after num_shapes * views.num_views outputs. It returns early
        when no checkpoint is found in config.train_dir.
        """

        print('Testing...')

        self.saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(self.config.train_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('Cannot find any checkpoint file')
            return
        self.saver.restore(sess, ckpt.model_checkpoint_path)
        self.step = self._parse_checkpoint_step(ckpt.model_checkpoint_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        self.summarizer = tf.summary.FileWriter(self.config.test_dir, sess.graph)

        output_count = 0
        output_prefix = 'dn14'
        output_images_folder = 'images'
        output_results_folder = 'results'
        num_outputs = num_shapes * views.num_views

        log_file_name = os.path.join(self.config.test_dir, 'log.txt')
        # 'with' guarantees the log is flushed and closed (previously it was leaked)
        with open(log_file_name, 'a') as log_file:
            started = False
            finished = False
            last_shape_name = ''
            last_view_name = ''
            while not finished:
                names, _results, errors, images = sess.run([self.names, self.results, self.errors, self.pngs])
                for k in range(len(names)):
                    shape_name, view_name = names[k].decode('utf8').split('--')
                    if last_shape_name == shape_name:
                        # consecutive entries of the same shape are renumbered sequentially
                        view_name = ('%s' % (int(last_view_name) + 1))
                    last_shape_name = shape_name
                    last_view_name = view_name
                    print('Processed %d: %s--%s %f' % (output_count, shape_name, view_name, errors[k]))

                    if view_name == '0' and started:
                        log_file.write('\n')
                    started = True
                    log_file.write('%6f ' % errors[k])

                    # export images (indices follow the order of self.pngs)
                    image_dir = os.path.join(self.config.test_dir, output_images_folder, shape_name)
                    image.write_image(os.path.join(image_dir, 'input.png'), images[0, k])
                    for index, kind in enumerate(('gt', 'pred', 'normal', 'depth', 'mask'), start=1):
                        file_name = kind + '-' + output_prefix + '--' + view_name + '.png'
                        image.write_image(os.path.join(image_dir, file_name), images[index, k])

                    # export results
                    result_dir = os.path.join(self.config.test_dir, output_results_folder, shape_name)
                    image.write_image(os.path.join(result_dir, output_prefix + '-' + view_name + '.png'), images[2, k])

                    # check termination
                    output_count += 1
                    if output_count >= num_outputs:
                        finished = True
                        break

        coord.request_stop()
        coord.join(threads)

    def encode(self, sess, views, num_shapes):
        """Restore the latest checkpoint and write one feature file per shape.

        Features are written with data.write_bin_data to
        <encode_dir>/features/<shape>.bin until num_shapes shapes are done.
        Returns early when no checkpoint is found in config.train_dir.
        """

        print('Encoding...')

        self.saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(self.config.train_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('Cannot find any checkpoint file')
            return
        self.saver.restore(sess, ckpt.model_checkpoint_path)
        # same tolerant parsing as train()/test(): a non-numeric suffix gives step 0
        self.step = self._parse_checkpoint_step(ckpt.model_checkpoint_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        self.summarizer = tf.summary.FileWriter(self.config.encode_dir, sess.graph)

        output_count = 0
        output_folder = 'features'

        finished = False
        while not finished:
            names, features = sess.run([self.encode_names, self.encode_features])
            for k in range(len(names)):
                shape_name = names[k].decode('utf8')
                print('Processed %d: %s' % (output_count, shape_name))

                # export results
                name_output = os.path.join(self.config.encode_dir, output_folder, shape_name + '.bin')
                data.write_bin_data(name_output, features[k])

                # check termination
                output_count += 1
                if output_count >= num_shapes:
                    finished = True
                    break

        coord.request_stop()
        coord.join(threads)

    def validate_loss(self, sess, num_shapes):
        """Average the validation losses over num_shapes shapes and log them.

        Runs self.valid_losses batch by batch until at least num_shapes shapes
        have been processed. The per-batch loss vectors are summed and divided
        by the number of processed shapes. The method then prints element 0
        (loss_G) and writes the validation summaries. It does nothing when
        num_shapes <= 0.
        """
        if num_shapes <= 0:
            return  # nothing to validate; the accumulator would still be None

        num_processed_shapes = 0
        cum_losses = None
        while num_processed_shapes < num_shapes:
            losses = np.array(sess.run(self.valid_losses))
            cum_losses = losses if cum_losses is None else cum_losses + losses
            num_processed_shapes += self.config.batch_size
        cum_losses /= num_processed_shapes

        print('===== validation loss: %.3g' % cum_losses[0])

        summary_str = sess.run(self.valid_summary_op, feed_dict={self.valid_summary_losses: cum_losses})
        self.summarizer.add_summary(summary_str, self.step)

    def validate_output(self, sess, num_shapes, epochs):
        """Write images for one validation batch, then fetch the rest of the set.

        Images go to <train_dir>/epoch-<epochs>/<shape>/. The remaining batches
        (up to num_shapes shapes) are fetched via self.names and discarded --
        presumably so the next validation pass starts at the beginning of the
        validation queue; confirm against the input pipeline.
        """

        print('===== validation output')
        valid_results_folder = 'epoch-%.2f' % epochs
        names, images = sess.run([self.names, self.valid_images])

        for k in range(len(names)):
            shape_name, view_name = names[k].decode('utf8').split('--')
            if view_name == '0':
                print(shape_name)

            # indices follow the order of self.valid_images
            output_dir = os.path.join(self.config.train_dir, valid_results_folder, shape_name)
            for index, kind in enumerate(('output', 'gt', 'normal', 'depth', 'mask')):
                image.write_image(os.path.join(output_dir, kind + '--' + view_name + '.png'), images[index, k])

        # loop over all remaining shapes in the queue...
        num_processed_shapes = self.config.batch_size
        while num_processed_shapes < num_shapes:
            sess.run(self.names)
            num_processed_shapes += self.config.batch_size

    def cumulate_gradients(self, cum_grads, grads):
        """Add grads element-wise into cum_grads and return the accumulator.

        On the first call (cum_grads is None) the grads list itself becomes the
        accumulator. Later calls update its entries in place.
        """
        if cum_grads is None:
            return grads
        for k, grad in enumerate(grads):
            cum_grads[k] += grad
        return cum_grads