├── img
└── img.png
├── get-list.py
├── README.md
├── view.py
├── layer.py
├── image.py
├── main.py
├── reproject.py
├── network.py
├── loss.py
├── data.py
└── monnet.py
/img/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yueeey/sketcch3D/HEAD/img/img.png
--------------------------------------------------------------------------------
/get-list.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | p = Path('/vol/research/zy/dataSets/shapeMVD/Chair/hires/03001627')
4 | folder_list = [x for x in p.iterdir() if x.is_dir()]
--------------------------------------------------------------------------------
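The script above collects the shape folders but never writes them out; presumably the names are then saved in the `<synset>/<model>` form that `data.py` reads back from `train-list.txt` / `test-list.txt`. A minimal sketch of that missing step, assuming that format:

```python
# Hypothetical continuation of get-list.py: persist the shape names as
# '<synset>/<model>' lines, the format data.py's load_train_data /
# load_test_data expect in train-list.txt / test-list.txt.
with open('train-list.txt', 'w') as f:
    for folder in folder_list:
        f.write('%s/%s\n' % (folder.parent.name, folder.name))
```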
/README.md:
--------------------------------------------------------------------------------
1 | # Towards Practical Sketch-Based 3D Shape Generation
2 |
3 | ## Contents
4 |
5 | - [Introduction](#introduction)
6 | - [Requirements](#requirements)
7 | - [Download Dataset](#download-dataset)
8 | - [Results](#results)
9 |
10 | ## Introduction
11 |
12 | This repository contains the PyTorch implementation of [Towards Practical Sketch-Based 3D Shape Generation](https://ieeexplore.ieee.org/document/9272370).
13 |
14 | You can find detailed usage instructions for training and evaluation below.
15 |
16 | If you use our code or dataset, please cite our work:
17 |
18 | @ARTICLE{sketch3d2020,
19 | author={Zhong, Yue and Qi, Yonggang and Gryaditskaya, Yulia and Zhang, Honggang and Song, Yi-Zhe},
20 | journal={IEEE Transactions on Circuits and Systems for Video Technology},
21 | title={Towards Practical Sketch-Based 3D Shape Generation: The Role of Professional Sketches},
22 | year={2021},
23 | volume={31},
24 | number={9},
25 | pages={3518-3528},
26 | doi={10.1109/TCSVT.2020.3040900}
27 | }
28 |
29 | ## Requirements
30 |
31 | First, make sure that you have all dependencies in place.
32 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/).
34 | Please refer to the README file in each sub-task for detailed instructions (a quick dependency sanity check is sketched after this file).
36 | ## Download Dataset
37 |
38 | Downloading the dataset is easy: download it directly from [Dataset](https://pan.baidu.com/s/1wpf6Tc7h55TN6bdUYXQsPQ) with extraction code fhp7.
39 |
40 | Most of our experiments are conducted on models from the chair category of ShapeNetCore V2. We selected this category guided by the following principles: 1) easy to sketch; 2) generality; 3) view differentiability; 4) shape genus higher than 1; 5) large inter-category variance. We generate sketches in three distinctive styles, which we refer to as naive, stylized and style-unified. Please refer to the paper for further details.
41 |
42 |
43 | ## Results
44 |
45 | We show improved performance of deep image modeling.
46 |
47 | ![Results](img/img.png)
--------------------------------------------------------------------------------
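As a quick sanity check that the environment is in place (a hedged sketch; the network code uses the TensorFlow 1.x `tf.contrib` API, so a 1.x install is assumed):

```python
# Minimal dependency check: the code base imports these packages and
# relies on tf.contrib, which only exists in TensorFlow 1.x.
import tensorflow as tf
import numpy as np
import scipy

print(tf.__version__)  # expect 1.x; tf.contrib is gone in TF 2.x
```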
/view.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | #import tensorflow as tf
23 | import numpy as np
24 |
25 | class Views(object):
26 |
27 | def __init__(self, filename, num_views=-1):
28 | """
29 | self.views: V x 3
30 | self.groups: G x v
31 | """
32 |
33 | f = open(filename, 'r')
34 |
35 | f.readline() # OFF
36 | self.num_views, self.num_groups, num_edges = map(int, f.readline().split())
37 |
38 | view_data = []
39 | for view_id in range(self.num_views):
40 | view_data.append(list(map(float, f.readline().split())))
41 | self.views = np.array(view_data)
42 |
43 | group_data = []
44 | for group_id in range(self.num_groups):
45 | group_data.append(list(map(int, f.readline().split()[1:])))
46 | self.groups = np.array(group_data)
47 |
48 | f.close()
49 |
50 | if num_views >= 0: # select views
51 | self.num_views = num_views
52 | self.num_groups = 0
53 | self.views = self.views[:self.num_views]
54 | self.groups = self.groups[:0]
55 |
56 | self.num_edges = self.num_views+self.num_groups-2
57 | self.edge_size = 2
58 |
59 | # HACK: minimal data for local testing
60 | #self.num_views = 3
61 | #self.num_groups = 0
62 | #self.num_edges = 1
63 | #self.edge_size = 2
64 | #self.views = self.views[:self.num_views]
65 | #self.groups = self.groups[:self.num_groups]
66 |
67 | #print('Views:')
68 | #print(self.views)
69 | #print('Groups:')
70 | #print(self.groups)
71 |
72 | def view2angle(view):
73 | """
74 | input:
75 | view : 3 : (x,y,z)
76 | output:
77 | angle : 4 : (cos(theta), sin(theta), cos(phi), sin(phi))
78 | """
79 |
80 | r = np.linalg.norm(view) # sqrt(x^2+y^2+z^2)
81 | rxz = np.linalg.norm(view[[0,2]]) # sqrt(x^2+z^2)
82 | ct = view[1] / r # cos(theta) = y/r
83 | st = rxz / r # sin(theta) = sqrt(x^2+z^2)/r
84 | if rxz>0:
85 | cp = view[0] / rxz # cos(phi) = x / sqrt(x^2+z^2)
86 | sp = view[2] / rxz # sin(phi) = z / sqrt(x^2+z^2)
87 | else: # zenith point
88 | cp = 0.0
89 | sp = 0.0
90 | return [ct, st, cp, sp]
91 |
--------------------------------------------------------------------------------
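A minimal usage sketch for the `Views` helper above (assuming a `view.off` file laid out as the loader expects: an OFF header line, then view positions, then view groups):

```python
# Load viewpoints and convert the first one into the
# (cos(theta), sin(theta), cos(phi), sin(phi)) parameterization
# consumed by the continuous-view network.
import view as vw

views = vw.Views('view.off')
print(views.num_views, views.num_groups)
print(vw.view2angle(views.views[0]))
```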
/layer.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 |
25 | import tensorflow.contrib.layers as tf_layers
26 | import tensorflow.contrib.framework as tf_framework
27 |
28 | WEIGHT_STDDEV = 0.005
29 | WEIGHT_DECAY = 0.0001
30 | BN_DECAY = 0.997
31 | BN_EPSILON = 1e-5
32 |
33 | def leaky_relu(tensor, slope=0.2):
34 | """
35 | input:
36 | tensor : input tensor of any shape
37 | output:
38 | result : output tensor having the same shape as input tensor
39 | """
40 | return tf.maximum(tensor*slope, tensor)
41 |
42 | def unet_scopes(bn_scope):
43 |
44 | bn_params = {
45 | 'is_training': True,
46 | 'decay': BN_DECAY,
47 | 'epsilon': BN_EPSILON,
48 | 'trainable': False,
49 | 'updates_collections': bn_scope,
50 | }
51 |
52 | with tf_framework.arg_scope(
53 | [tf_layers.conv2d, tf_layers.fully_connected],
54 | weights_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_STDDEV),
55 | weights_regularizer=tf_layers.l2_regularizer(WEIGHT_DECAY),
56 | biases_initializer=tf.zeros_initializer(),
57 | normalizer_fn=tf_layers.batch_norm,
58 | normalizer_params=bn_params,
59 | activation_fn=leaky_relu) as scope:
60 | if bn_scope is None:
61 | return scope
62 | else:
63 | with tf_framework.arg_scope([tf_layers.batch_norm], **bn_params) as scope_with_bn:
64 | return scope_with_bn
65 |
66 | def cnet_scopes(bn_scope):
67 |
68 | with tf_framework.arg_scope(
69 | [tf_layers.conv2d, tf_layers.fully_connected],
70 | weights_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_STDDEV),
71 | weights_regularizer=tf_layers.l2_regularizer(WEIGHT_DECAY),
72 | biases_initializer=tf.zeros_initializer(),
73 | normalizer_fn=None,
74 | activation_fn=leaky_relu) as scope:
75 | return scope
76 |
77 | def residual_layer(inputs, kernel, scope):
78 | """
79 | input:
80 | inputs : n x H x W x C feature maps to be passed into residual block
81 | kernel : scalar internal filter kernel size
82 | scope : string scope name
83 | output:
84 | outputs : n x H x W x C output feature maps
85 | """
86 |
87 | channels = inputs.get_shape()[3].value
88 | layer1 = tf_layers.conv2d(inputs, num_outputs=channels, kernel_size=kernel, stride=1, scope=scope+'/layer1')
89 | layer2 = tf_layers.conv2d(layer1, num_outputs=channels, kernel_size=kernel, stride=1, scope=scope+'/layer2', activation_fn=None)
90 | outputs = layer2 + inputs
91 | return outputs
92 |
93 | def unconv_layer(inputs, num_outputs, kernel_size, stride, scope, normalizer_fn=tf_layers.batch_norm, activation_fn=tf.nn.relu):
94 | """
95 | input:
96 | inputs : n x H x W x C feature maps to be passed into unconv layer
97 | num_outputs : scalar number of channels in output feature map
98 | kernel_size : scalar internal filter kernel size
99 | scope : string scope name
100 | normalizer_fn : function normalizer function
101 | activation_fn : function activation function
102 | output:
103 | outputs : n x H x W x C output feature maps
104 | """
105 |
106 | # return tf_layers.conv2d_transpose(inputs, num_outputs=num_outputs, kernel_size=kernel_size, stride=stride, scope=scope, normalizer_fn=normalizer_fn, activation_fn=activation_fn)
107 |
108 | h = inputs.get_shape()[1].value
109 | w = inputs.get_shape()[2].value
110 | c = inputs.get_shape()[3].value
111 |
112 | # upsampled = tf.image.resize_bilinear(inputs, [h*stride, w*stride])
113 | upsampled = tf.image.resize_nearest_neighbor(inputs, [h*stride, w*stride])
114 |
115 | outputs = tf_layers.conv2d(upsampled, num_outputs=num_outputs, kernel_size=kernel_size, stride=1, scope=scope, normalizer_fn=normalizer_fn, activation_fn=activation_fn)
116 |
117 | # features = tf_layers.conv2d(upsampled, num_outputs=c, kernel_size=kernel_size, stride=1, scope=scope+'/conv1')
118 | # outputs = tf_layers.conv2d(features, num_outputs=num_outputs, kernel_size=kernel_size, stride=1, scope=scope+'/conv2', normalizer_fn=normalizer_fn, activation_fn=activation_fn)
119 |
120 | return outputs
121 |
--------------------------------------------------------------------------------
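`unconv_layer` above replaces a transposed convolution with nearest-neighbor upsampling followed by a stride-1 convolution, a common trick to avoid checkerboard artifacts. An illustrative NumPy sketch of just the upsampling step:

```python
# Nearest-neighbor upsampling by `stride`: every value becomes a
# stride x stride block, which the following stride-1 conv then smooths.
import numpy as np

def nn_upsample(feature_map, stride=2):
    # feature_map: H x W x C
    return feature_map.repeat(stride, axis=0).repeat(stride, axis=1)

x = np.arange(4.0).reshape(2, 2, 1)
print(nn_upsample(x)[:, :, 0])  # each input value expands to a 2x2 block
```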
/image.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 | from scipy import ndimage
25 |
26 | import os
27 |
28 | ########################### image processing ###########################
29 |
30 | def normalize_image(image):
31 | # normalize to [-1.0, 1.0]
32 | if image.dtype == tf.uint8:
33 | return tf.to_float(image)/127.5-1.0
34 | elif image.dtype == tf.uint16:
35 | return tf.to_float(image)/32767.5-1.0
36 | else:
37 | return tf.to_float(image)
38 |
39 | def unnormalize_image(image, maxval=255.0):
40 | # restore image to [0.0, maxval]
41 | return (image+1.0)*maxval*0.5
42 |
43 | def saturate_image(image, dtype=tf.uint8):
44 | return tf.saturate_cast(image, dtype)
45 |
46 | def convert_to_rgb(image, channels=3):
47 | return tf.tile(image, [1,1,1,channels])
48 |
49 |
50 | ########################### masks ###########################
51 |
52 | def extract_boolean_mask(image):
53 | """
54 | input:
55 | image: n x H x W x C : images with value range [-1.0, 1.0] in each channel
56 | output:
57 | mask: n x H x W x 1 : boolean mask (depth channel value < 0.9)
58 | """
59 |
60 | depth = tf.slice(image, [0,0,0,3], [-1,-1,-1,1])
61 | shape = depth.get_shape()
62 | mask = tf.where(tf.greater(depth, 0.9),
63 | tf.constant(False, dtype=tf.bool, shape=shape),
64 | tf.constant(True, dtype=tf.bool, shape=shape))
65 | return mask
66 |
67 | def convert_to_real_mask(bool_mask):
68 | """
69 | input:
70 | bool_mask: boolean mask image
71 | output:
72 | real_mask: real number mask image (-1.0: false, 1.0: true)
73 | """
74 |
75 | shape = bool_mask.get_shape()
76 | return tf.where(bool_mask,
77 | tf.constant(1.0, dtype=tf.float32, shape=shape),
78 | tf.constant(-1.0, dtype=tf.float32, shape=shape))
79 |
80 | def convert_to_boolean_mask(real_mask):
81 | """
82 | input:
83 | real_mask: real number mask image (-1.0: false, 1.0: true)
84 | output:
85 | bool_mask: boolean mask image
86 | """
87 | shape = real_mask.get_shape()
88 | return tf.where(tf.greater(real_mask, 0.0),
89 | tf.constant(True, dtype=tf.bool, shape=shape),
90 | tf.constant(False, dtype=tf.bool, shape=shape))
91 |
92 | def apply_mask(content, mask):
93 | """
94 | input:
95 | content: n x H x W x C : image content
96 | mask: n x H x W x 1 : image mask (>0: true)
97 | output:
98 | output: use content value if mask is true; 1.0 otherwise
99 | """
100 | channel = content.get_shape()[3].value
101 | if channel > 1:
102 | mask = tf.tile(mask, [1,1,1,channel])
103 | return tf.where(tf.greater(mask, 0.0), content, tf.ones_like(content))
104 |
105 |
106 | ########################### filter ###########################
107 |
108 | def get_sobel_filter():
109 |
110 | # 3x3 sobel filter
111 | filter_v = tf.convert_to_tensor(np.array([ \
112 | [-1.0, 0.0, 1.0],
113 | [-2.0, 0.0, 2.0],
114 | [-1.0, 0.0, 1.0]]), dtype=tf.float32)
115 | filter_h = tf.convert_to_tensor(np.array([ \
116 | [ 1.0, 2.0, 1.0],
117 | [ 0.0, 0.0, 0.0],
118 | [-1.0, -2.0, -1.0]]), dtype=tf.float32)
119 | return filter_v, filter_h
120 |
121 | def get_dog_filter(kernel_size):
122 |
123 | # derivative of gaussian filter
124 | kernel_point = np.zeros((kernel_size, kernel_size))
125 | kernel_point[kernel_size//2,kernel_size//2] = 1
126 | kernel_v = ndimage.filters.gaussian_filter(kernel_point, sigma=kernel_size//2, order=[0,1]) * (kernel_size*kernel_size)
127 | kernel_h = kernel_v.T
128 | filter_v = tf.constant(kernel_v, dtype=tf.float32)
129 | filter_h = tf.constant(kernel_h, dtype=tf.float32)
130 | filter_v = tf.expand_dims(tf.expand_dims(filter_v, -1), -1)
131 | filter_h = tf.expand_dims(tf.expand_dims(filter_h, -1), -1)
132 | return filter_v, filter_h
133 |
134 | def apply_edge_filter(images):
135 | """
136 | input:
137 | images: n x H x W x C input images
138 | output:
139 | outputs: n x H x W x 1 output edge images
140 | """
141 |
142 | if images.get_shape()[3].value == 1:
143 | gray_images = images
144 | else:
145 | gray_images = tf.image.rgb_to_grayscale(images)
146 |
147 | if not hasattr(apply_edge_filter, "filter"):
148 | apply_edge_filter.filter = get_dog_filter(15)
149 |
150 | edge_v = tf.nn.conv2d(gray_images, filter=apply_edge_filter.filter[0], strides=[1,1,1,1], padding='SAME')
151 | edge_h = tf.nn.conv2d(gray_images, filter=apply_edge_filter.filter[1], strides=[1,1,1,1], padding='SAME')
152 | outputs = tf.square(edge_v) + tf.square(edge_h)
153 |
154 | return outputs
155 |
156 |
157 | ########################### encoding ###########################
158 |
159 | def encode_batch_images(batch):
160 | """
161 | input:
162 | batch: n x H x W x C input images batch
163 | output:
164 | packed: n x String output PNG-encoded strings
165 | """
166 | # output:
167 | unpacked = tf.unstack(batch)
168 | num = len(unpacked)
169 | encoded = [None] * num
170 | for k in range(num):
171 | encoded[k] = tf.image.encode_png(unpacked[k])
172 | return tf.stack(encoded)
173 |
174 | def encode_raw_batch_images(batch):
175 | """
176 | input:
177 | batch: n x H x W x C input raw images batch
178 | output:
179 | packed: n x String output PNG-encoded strings
180 | """
181 | return encode_batch_images(saturate_image(unnormalize_image(batch)))
182 |
183 | def write_image(name, image):
184 | """
185 | input:
186 | name: String file name
187 | image: String PNG-encoded string
188 | """
189 | path = os.path.dirname(name)
190 | if not os.path.exists(path):
191 | os.makedirs(path)
192 | file = open(name, 'wb')
193 | file.write(image)
194 | file.close()
195 |
196 |
--------------------------------------------------------------------------------
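The normalization convention used throughout `image.py` maps pixel values into [-1, 1] and back; a small NumPy round-trip check:

```python
# uint8 pixels -> [-1, 1] (normalize_image, uint8 branch) and back
# (unnormalize_image with maxval=255, before saturation to uint8).
import numpy as np

pix = np.array([0, 128, 255], dtype=np.float32)
norm = pix / 127.5 - 1.0            # [-1., 0.0039, 1.]
back = (norm + 1.0) * 255.0 * 0.5   # [0., 128., 255.]
print(norm, back)
```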
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 |
24 | import time
25 | import os
26 |
27 | import data
28 | import monnet as mn
29 | import view as vw
30 |
31 | FLAGS = tf.app.flags.FLAGS
32 |
33 | tf.app.flags.DEFINE_boolean('train', False,
34 | """Flag for training routine.""")
35 | tf.app.flags.DEFINE_boolean('test', False,
36 | """Flag for testing routine.""")
37 | tf.app.flags.DEFINE_boolean('encode', False,
38 | """Flag for encoding routine.""")
39 | tf.app.flags.DEFINE_boolean('predict_normal', True,
40 | """Flag for predicting normal.""")
41 | tf.app.flags.DEFINE_boolean('continuous_view', False,
42 | """Flag for using continuous view architecture.""")
43 | tf.app.flags.DEFINE_boolean('no_adversarial', False,
44 | """Flag for disabling the adversarial loss term.""")
45 | tf.app.flags.DEFINE_integer('batch_size', 2,
46 | """Number of images to process in a batch.""")
47 | tf.app.flags.DEFINE_integer('image_size', 256,
48 | """Size of images to be learned.""")
49 | tf.app.flags.DEFINE_integer('sketch_variations', 4,
50 | """Number of variations on input source.""")
51 | tf.app.flags.DEFINE_string('sketch_views', 'F',
52 | """Views used in sketch input ( [F]ront / [T]op / [S]ide )""")
53 | tf.app.flags.DEFINE_float('max_epochs', 100.0,
54 | """Maximum epochs for optimization.""")
55 | tf.app.flags.DEFINE_float('gpu_fraction', 0.9,
56 | """Upper-bound fraction of GPU memory usage.""")
57 | tf.app.flags.DEFINE_string('data_dir', '/vol/research/zy/dataSets/shapeMVD/Chair/',
58 | """Directory containing training/testing images.""")
59 | tf.app.flags.DEFINE_string('sketch_dir', '/vol/research/ycres/zy/dataSets/occ/ShapeNet/',
60 | """Directory containing input sketch images.""")
61 | tf.app.flags.DEFINE_string('sketch_set', '/naive_mad',
62 | """Sketch style sub-directory appended to each shape path.""")
63 | tf.app.flags.DEFINE_string('train_dir', '/vol/research/zyres/3dv/baselines/SketchModeling/Network/Checkpoint/',
64 | """Directory where to write training logs.""")
65 | tf.app.flags.DEFINE_string('test_dir', '/vol/research/zyres/3dv/baselines/SketchModeling/Network/output/sty_mad1/',
66 | """Directory where to write testing logs.""")
67 | tf.app.flags.DEFINE_string('check_dir', '/vol/research/zyres/3dv/baselines/SketchModeling/Network/output/sty_mad/',
68 | """Directory with existing test results, used to skip already-processed shapes.""")
69 | tf.app.flags.DEFINE_string('encode_dir', './../../../../Data/CharacterDraw/encode/',
70 | """Directory where to write encoding logs.""")
71 | tf.app.flags.DEFINE_string('view_file', 'view.off',
72 | """File with view points information.""")
73 |
74 | def main(argv=None):
75 |
76 | print('start running...')
77 | start_time = time.time()
78 |
79 | ############################################ build graph ############################################
80 |
81 | monnet = mn.MonNet(FLAGS)
82 |
83 | if int(FLAGS.train) + int(FLAGS.test) + int(FLAGS.encode) != 1:
84 | print('please specify \'train\' or \'test\' or \'encode\'')
85 | return
86 |
87 | views = vw.Views(os.path.join(FLAGS.data_dir, 'view', FLAGS.view_file))
88 |
89 | if FLAGS.train:
90 | train_names, train_sources, train_targets, train_masks, train_angles, num_train_shapes = data.load_train_data(FLAGS, views)
91 | valid_names, valid_sources, valid_targets, valid_masks, valid_angles, num_valid_shapes = data.load_validate_data(FLAGS, views)
92 |
93 | with tf.variable_scope("monnet") as scope:
94 | monnet.build_network(\
95 | names=train_names,
96 | sources=train_sources,
97 | targets=train_targets,
98 | masks=train_masks,
99 | angles=train_angles,
100 | views=views,
101 | is_training=True)
102 | scope.reuse_variables() # sharing weights
103 | monnet.build_network(\
104 | names=valid_names,
105 | sources=valid_sources,
106 | targets=valid_targets,
107 | masks=valid_masks,
108 | angles=valid_angles,
109 | views=views,
110 | is_validation=True)
111 | elif FLAGS.test:
112 | test_names, test_sources, test_targets, test_masks, test_angles, num_test_shapes = data.load_test_data(FLAGS, views)
113 |
114 | with tf.variable_scope("monnet") as scope:
115 | monnet.build_network(\
116 | names=test_names,
117 | sources=test_sources,
118 | targets=test_targets,
119 | masks=test_masks,
120 | angles=test_angles,
121 | views=views,
122 | is_testing=True)
123 | elif FLAGS.encode:
124 | encode_names, encode_sources, encode_targets, encode_masks, encode_angles, num_encode_shapes = data.load_encode_data(FLAGS, views)
125 |
126 | with tf.variable_scope("monnet") as scope:
127 | monnet.build_network(\
128 | names=encode_names,
129 | sources=encode_sources,
130 | targets=encode_targets,
131 | masks=encode_masks,
132 | angles=encode_angles,
133 | views=views,
134 | is_encoding=True)
135 |
136 |
137 | ############################################ compute graph ############################################
138 |
139 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
140 |
141 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
142 | log_device_placement=False,
143 | allow_soft_placement=True)) as sess:
144 |
145 | if FLAGS.train:
146 | monnet.train(sess, views, num_train_shapes, num_valid_shapes)
147 | elif FLAGS.test:
148 | monnet.test(sess, views, num_test_shapes)
149 | elif FLAGS.encode:
150 | monnet.encode(sess, views, num_encode_shapes)
151 |
152 | sess.close()
153 |
154 | duration = time.time() - start_time
155 | print('total running time: %.1f sec\n' % duration)
156 |
157 |
158 | if __name__ == '__main__':
159 | tf.app.run()
--------------------------------------------------------------------------------
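For reference, `main.py` requires exactly one of the `--train`, `--test`, or `--encode` flags; the remaining flags default to the authors' file system, so the directory flags need to be overridden for your setup. A typical (hypothetical) invocation would look like `python main.py --train --data_dir /path/to/Chair/ --train_dir ./checkpoints/`.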
/reproject.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 |
25 | import os
26 |
27 | class ReProj(object):
28 |
29 | def __init__(self):
30 | self.proj = np.identity(4)
31 | self.view = np.identity(4)
32 |
33 | def set_ortho_projection(self, l=-2.5, r=2.5, b=-2.5, t=2.5, n=0.1, f=5.0):
34 | """
35 | args:
36 | l: left
37 | r: right
38 | b: bottom
39 | t: top
40 | n: near
41 | f: far
42 | ref: https://www.opengl.org/sdk/docs/man2/xhtml/glOrtho.xml
43 | """
44 | self.proj = np.array([ \
45 | [2.0/(r-l), 0.0, 0.0, -(r+l)/(r-l)],
46 | [0.0, 2.0/(t-b), 0.0, -(t+b)/(t-b)],
47 | [0.0, 0.0, -2.0/(f-n), -(f+n)/(f-n)],
48 | [0.0, 0.0, 0.0, 1.0 ]])
49 | self.proj_inv = np.linalg.inv(self.proj)
50 |
51 | def set_viewpoint(self, viewpoint):
52 | """
53 | args:
54 | viewpoint: eye position (assuming center at origin, up on Y axis)
55 | ref: http://www.ibm.com/support/knowledgecenter/ssw_aix_53/com.ibm.aix.opengl/doc/openglrf/gluLookAt.htm
56 | """
57 | E = viewpoint
58 | C = np.array([0.0, 0.0, 0.0])
59 | U = np.array([0.0, 1.0, 0.0])
60 | L = C-E
61 | L = L/np.linalg.norm(L)
62 | S = np.cross(L, U)
63 | if np.linalg.norm(S) == 0:
64 | U = np.array([0.0, 0.0, -1.0])
65 | S = np.cross(L, U)
66 | S = S/np.linalg.norm(S)
67 | Up = np.cross(S, L)
68 | R = np.identity(4)
69 | R[0, 0:3] = S
70 | R[1, 0:3] = Up
71 | R[2, 0:3] = -L
72 | T = np.identity(4)
73 | T[0:3, 3] = -E
74 | self.view = np.dot(R, T)
75 | self.view_inv = np.linalg.inv(self.view)
76 |
77 | def transform(self, depth):
78 | """
79 | input:
80 | depth: H x W depth map with value range [-1, 1]
81 | output:
82 | points: (HxW) x 3 point set
83 | """
84 | H = depth.shape[0]
85 | W = depth.shape[1]
86 | num_points = np.count_nonzero(depth<1.0)
87 | valid_points = [None] * num_points
88 | point_id = 0
89 | for u in range(W):
90 | for v in range(H):
91 | if depth[v,u] < 1.0:
92 | valid_points[point_id] = [(u*2.0+1.0-W)/W, (H-v*2.0-1.0)/H, depth[v,u], 1.0]
93 | point_id += 1
94 | if num_points<=0:
95 | valid_points = np.empty([0,4])
96 | points = np.dot(self.view_inv, np.dot(self.proj_inv, np.array(valid_points).T))[0:3,:].T
97 | return points
98 |
99 | def export_ply(filename, points, normals=None):
100 | """
101 | args:
102 | filename: string file name
103 | points: (HxW) x 3 point set
104 | normals: (HxW) x 3 point set
105 | """
106 | path = os.path.dirname(filename)
107 | if not os.path.exists(path):
108 | os.makedirs(path)
109 | f = open(filename, 'w')
110 | f.write('ply\n')
111 | f.write('format ascii 1.0\n')
112 | f.write('element vertex %d\n' % points.shape[0])
113 | f.write('property float x\n')
114 | f.write('property float y\n')
115 | f.write('property float z\n')
116 | if normals is not None:
117 | f.write('property float nx\n')
118 | f.write('property float ny\n')
119 | f.write('property float nz\n')
120 | f.write('end_header\n')
121 | for k in range(points.shape[0]):
122 | f.write('%f %f %f\n' % (points[k,0], points[k,1], points[k,2]))
123 | if normals is not None:
124 | f.write('%f %f %f\n' % (normals[k,0], normals[k,1], normals[k,2]))
125 | f.close()
126 |
127 | def transform_tensor(predicts, views):
128 | """
129 | input:
130 | predicts : (n*V) x H x W x 4 predicted tensor (in n batches & V views)
131 | views : V x 3 view point positions (numpy array)
132 | output:
133 | points : (n*V) x H x W x 3 re-projected points position tensor
134 | dirs : (n*V) x H x W x 3 re-projected normals direction tensor
135 | """
136 |
137 | shape = predicts.get_shape().as_list()
138 | num_views = views.shape[0]
139 | num_batches = shape[0] // num_views
140 |
141 | # calculate reprojection matrix
142 |
143 | reproj = ReProj()
144 | reproj.set_ortho_projection()
145 |
146 | xform_per_view = [None] * num_views
147 | rotate_per_view = [None] * num_views
148 | for view_id in range(num_views):
149 | reproj.set_viewpoint(views[view_id,:])
150 | xform_per_view[view_id] = tf.constant(np.dot(reproj.view_inv, reproj.proj_inv), dtype=tf.float32) # [4 x 4] * V
151 | rotate_per_view[view_id] = tf.constant(reproj.view_inv, dtype=tf.float32) # [4 x 4] * V
152 |
153 | # separate depth/normal by views
154 |
155 | predicts_per_view = tf.transpose(tf.reshape(predicts, [-1, num_views, shape[1], shape[2], shape[3]]), [1, 0, 2, 3, 4]) # V x n x H x W x 4
156 | depths_per_view = tf.unstack(tf.slice(predicts_per_view, [0,0,0,0,3], [-1,-1,-1,-1,1])) # [n x H x W x 1] * V
157 | normals_per_view = tf.unstack(tf.slice(predicts_per_view, [0,0,0,0,0], [-1,-1,-1,-1,3])) # [n x H x W x 3] * V
158 |
159 | # calculate projected coordinates
160 |
161 | H = shape[1]
162 | W = shape[2]
163 | vec_u = tf.constant([(u*2.0+1.0-W)/W for u in range(W)]) # W
164 | vec_v = tf.constant([(H-v*2.0-1.0)/H for v in range(H)]) # H
165 | mat_u = tf.tile(tf.reshape(vec_u, [1,1,-1,1]), (num_batches,H,1,1)) # n x H x W x 1
166 | mat_v = tf.tile(tf.reshape(vec_v, [1,-1,1,1]), (num_batches,1,W,1)) # n x H x W x 1
167 | mat_w = tf.ones([num_batches, H, W, 1])
168 |
169 | homo_points_per_view = [tf.concat([mat_u, mat_v, mat_d, mat_w], 3) for mat_d in depths_per_view] # [n x H x W x 4] * V
170 | homo_dirs_per_view = [tf.concat([mat_n, mat_w], 3) for mat_n in normals_per_view] # [n x H x W x 4] * V
171 |
172 | # transform points
173 |
174 | points_per_view = [None] * num_views
175 | dirs_per_view = [None] * num_views
176 | for view_id in range(num_views):
177 | xformed = tf.matmul(tf.reshape(homo_points_per_view[view_id], [-1,4]), xform_per_view[view_id], transpose_b=True) # (n*H*W) x 4
178 | rotated = tf.matmul(tf.reshape(homo_dirs_per_view[view_id], [-1,4]), rotate_per_view[view_id], transpose_b=True) # (n*H*W) x 4
179 | points_per_view[view_id] = tf.slice(tf.reshape(xformed, [-1,H,W,4]), [0,0,0,0], [-1,-1,-1,3]) # n x H x W x 3
180 | dirs_per_view[view_id] = tf.slice(tf.reshape(rotated, [-1,H,W,4]), [0,0,0,0], [-1,-1,-1,3]) # n x H x W x 3
181 |
182 | # organize output points
183 |
184 | points = tf.transpose(tf.stack(points_per_view), [1,0,2,3,4]) # n x v x H x W x 3
185 | points = tf.reshape(points, [-1, H, W, 3]) # (n*V) x H x W x 3
186 |
187 | dirs = tf.transpose(tf.stack(dirs_per_view), [1,0,2,3,4]) # n x v x H x W x 3
188 | dirs = tf.reshape(dirs, [-1, H, W, 3]) # (n*V) x H x W x 3
189 |
190 | return points, dirs
--------------------------------------------------------------------------------
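A small round-trip sketch for `ReProj` (assuming the default orthographic box and a front viewpoint on the +Z axis): one foreground pixel in a depth map is unprojected back to a world-space point, the same path `transform` takes.

```python
# Unproject a single foreground pixel: pixel -> NDC -> (proj_inv, view_inv)
# -> world space, using the defaults from set_ortho_projection.
import numpy as np
import reproject as rp

reproj = rp.ReProj()
reproj.set_ortho_projection()                    # [-2.5, 2.5] box, near 0.1, far 5.0
reproj.set_viewpoint(np.array([0.0, 0.0, 2.5]))  # front view on +Z

depth = np.ones((4, 4))    # depth 1.0 = background
depth[1, 2] = 0.0          # one foreground pixel
print(reproj.transform(depth))  # -> 1 x 3 world-space point
```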
/network.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 |
25 | import tensorflow.contrib.layers as tf_layers
26 |
27 | import layer
28 | import image
29 |
30 | def generateUNet(images, num_output_views, num_output_channels):
31 | """
32 | input:
33 | images : n x H x W x Ci input images ( 256 x 256 x Ci )
34 | num_output_views : int number of output views
35 | num_output_channels : int number of output image channels
36 | output:
37 | results : (n*m) x H x W x Co output images ( 256 x 256 x Co )
38 | features : n x D output features ( 2048 )
39 | """
40 |
41 | ###### encoding ######
42 |
43 | e1 = tf_layers.conv2d(images, num_outputs= 64, kernel_size=4, stride=2, scope='e1', normalizer_fn=None) # 128 x 128 x 64
44 | e2 = tf_layers.conv2d( e1, num_outputs=128, kernel_size=4, stride=2, scope='e2') # 64 x 64 x 128
45 | e3 = tf_layers.conv2d( e2, num_outputs=256, kernel_size=4, stride=2, scope='e3') # 32 x 32 x 256
46 | e4 = tf_layers.conv2d( e3, num_outputs=256, kernel_size=4, stride=2, scope='e4') # 16 x 16 x 256
47 | e5 = tf_layers.conv2d( e4, num_outputs=256, kernel_size=4, stride=2, scope='e5') # 8 x 8 x 256
48 | e6 = tf_layers.conv2d( e5, num_outputs=512, kernel_size=4, stride=2, scope='e6') # 4 x 4 x 512
49 | e7 = tf_layers.conv2d( e6, num_outputs=512, kernel_size=4, stride=2, scope='e7') # 2 x 2 x 512
50 |
51 | num_batches = images.get_shape()[0].value
52 | features = tf.reshape(e7, [num_batches, -1]) # 2048
53 |
54 | ###### decoding ######
55 |
56 | nc = num_output_channels
57 | rpv = [None] * num_output_views # results per view
58 | for view in range(num_output_views):
59 |
60 | with tf.variable_scope('decoder_%d' % view):
61 | d6 = tf_layers.dropout(layer.unconv_layer( e7, num_outputs=512, kernel_size=4, stride=2, scope='d6')) # 4 x 4 x 512
62 | d5 = tf_layers.dropout(layer.unconv_layer(tf.concat([d6, e6], 3), num_outputs=256, kernel_size=4, stride=2, scope='d5')) # 8 x 8 x 256
63 | d4 = layer.unconv_layer(tf.concat([d5, e5], 3), num_outputs=256, kernel_size=4, stride=2, scope='d4') # 16 x 16 x 256
64 | d3 = layer.unconv_layer(tf.concat([d4, e4], 3), num_outputs=256, kernel_size=4, stride=2, scope='d3') # 32 x 32 x 256
65 | d2 = layer.unconv_layer(tf.concat([d3, e3], 3), num_outputs=128, kernel_size=4, stride=2, scope='d2') # 64 x 64 x 128
66 | d1 = layer.unconv_layer(tf.concat([d2, e2], 3), num_outputs= 64, kernel_size=4, stride=2, scope='d1') # 128 x 128 x 64
67 | rpv[view] = layer.unconv_layer(tf.concat([d1, e1], 3), num_outputs= nc, kernel_size=4, stride=2, scope='re', normalizer_fn=None, activation_fn=tf.tanh)
68 |
69 | height = images.get_shape()[1].value
70 | width = images.get_shape()[2].value
71 | results = tf.reshape(tf.transpose(tf.stack(rpv), [1,0,2,3,4]), [-1, height, width, nc])
72 |
73 | return results, features
74 |
75 | def generateCNet(images, angles, num_output_channels):
76 | """
77 | input:
78 | images : n x H x W x Ci input images ( 256 x 256 x Ci )
79 | angles : n x 4 output viewing angle parameters
80 | num_output_channels : int number of output image channels
81 | output:
82 | results : n x H x W x Co output images ( 256 x 256 x Co )
83 | features : n x D output features ( 2048 )
84 | """
85 |
86 | ###### encoding ######
87 |
88 | e1 = tf_layers.conv2d(images, num_outputs= 64, kernel_size=4, stride=2, scope='e1', normalizer_fn=None) # 128 x 128 x 64
89 | e2 = tf_layers.conv2d( e1, num_outputs=128, kernel_size=4, stride=2, scope='e2') # 64 x 64 x 128
90 | e3 = tf_layers.conv2d( e2, num_outputs=256, kernel_size=4, stride=2, scope='e3') # 32 x 32 x 256
91 | e4 = tf_layers.conv2d( e3, num_outputs=512, kernel_size=4, stride=2, scope='e4') # 16 x 16 x 512
92 | e5 = tf_layers.conv2d( e4, num_outputs=512, kernel_size=4, stride=2, scope='e5') # 8 x 8 x 512
93 | e6 = tf_layers.conv2d( e5, num_outputs=512, kernel_size=4, stride=2, scope='e6') # 4 x 4 x 512
94 | e7 = tf_layers.conv2d( e6, num_outputs=512, kernel_size=4, stride=2, scope='e7') # 2 x 2 x 512
95 |
96 | num_batches = images.get_shape()[0].value
97 | ifeat = tf.reshape(e7, [num_batches, -1]) # 2048
98 | ifeat = tf_layers.fully_connected(ifeat, 2048, scope='ifc') # 2048
99 | features = ifeat
100 |
101 | vfeat = tf_layers.stack(
102 | angles,
103 | tf_layers.fully_connected,
104 | [64, # 64
105 | 64, # 64
106 | 64], # 64
107 | scope='vfc')
108 |
109 | ###### decoding ######
110 |
111 | nc = num_output_channels
112 | mp = 1 # multiplier for filter size (should be something close to the square root of number of output views)
113 |
114 | feat = tf_layers.stack(
115 | tf.concat([ifeat, vfeat], 1),
116 | tf_layers.fully_connected,
117 | [1024*mp, # 1024*mp
118 | 1024*mp, # 1024*mp
119 | 2048*mp], # 2048*mp
120 | scope='fc')
121 | feat = tf.reshape(feat, [-1, 2, 2, 512*mp]) # 2 x 2 x 512*mp
122 |
123 | #d6 = layer.unconv_layer( feat, num_outputs=512*mp, kernel_size=4, stride=2, scope='d6') # 4 x 4 x 512*mp
124 | #d5 = layer.unconv_layer(tf.concat([d6, e6], 3), num_outputs=512*mp, kernel_size=4, stride=2, scope='d5') # 8 x 8 x 512*mp
125 | #d4 = layer.unconv_layer(tf.concat([d5, e5], 3), num_outputs=512*mp, kernel_size=4, stride=2, scope='d4') # 16 x 16 x 512*mp
126 | #d3 = layer.unconv_layer(tf.concat([d4, e4], 3), num_outputs=256*mp, kernel_size=4, stride=2, scope='d3') # 32 x 32 x 256*mp
127 | #d2 = layer.unconv_layer(tf.concat([d3, e3], 3), num_outputs=128*mp, kernel_size=4, stride=2, scope='d2') # 64 x 64 x 128*mp
128 | #d1 = layer.unconv_layer(tf.concat([d2, e2], 3), num_outputs= 64*mp, kernel_size=4, stride=2, scope='d1') # 128 x 128 x 64*mp
129 | #results = layer.unconv_layer(tf.concat([d1, e1], 3), num_outputs= nc, kernel_size=4, stride=2, scope='re', normalizer_fn=None, activation_fn=tf.tanh)
130 |
131 | d6 = layer.unconv_layer(feat, num_outputs=512*mp, kernel_size=4, stride=2, scope='d6') # 4 x 4 x 512*mp
132 | d5 = layer.unconv_layer(d6, num_outputs=512*mp, kernel_size=4, stride=2, scope='d5') # 8 x 8 x 512*mp
133 | d4 = layer.unconv_layer(d5, num_outputs=512*mp, kernel_size=4, stride=2, scope='d4') # 16 x 16 x 512*mp
134 | d3 = layer.unconv_layer(d4, num_outputs=256*mp, kernel_size=4, stride=2, scope='d3') # 32 x 32 x 256*mp
135 | d2 = layer.unconv_layer(d3, num_outputs=128*mp, kernel_size=4, stride=2, scope='d2') # 64 x 64 x 128*mp
136 | d1 = layer.unconv_layer(d2, num_outputs= 64*mp, kernel_size=4, stride=2, scope='d1') # 128 x 128 x 64*mp
137 | results = layer.unconv_layer(d1, num_outputs= nc, kernel_size=4, stride=2, scope='re', normalizer_fn=None, activation_fn=tf.tanh)
138 |
139 | return results, features
140 |
141 | def discriminate(data):
142 | """
143 | input:
144 | data : n x H x W x C data to be discriminated ( 256 x 256 x C )
145 | output:
146 | probs : n probabilities being real
147 | """
148 |
149 | d1 = tf_layers.conv2d(data, num_outputs= 64, kernel_size=4, stride=2, scope='d1', normalizer_fn=None) # 128 x 128 x 64
150 | d2 = tf_layers.conv2d(d1, num_outputs=128, kernel_size=4, stride=2, scope='d2') # 64 x 64 x 128
151 | d3 = tf_layers.conv2d(d2, num_outputs=256, kernel_size=4, stride=2, scope='d3') # 32 x 32 x 256
152 | d4 = tf_layers.conv2d(d3, num_outputs=512, kernel_size=4, stride=2, scope='d4') # 16 x 16 x 512
153 | d5 = tf_layers.conv2d(d4, num_outputs=512, kernel_size=4, stride=2, scope='d5') # 8 x 8 x 512
154 | d6 = tf_layers.conv2d(d5, num_outputs=512, kernel_size=4, stride=2, scope='d6') # 4 x 4 x 512
155 | d7 = tf_layers.conv2d(d6, num_outputs=512, kernel_size=4, stride=2, scope='d7') # 2 x 2 x 512
156 |
157 | feature = tf.reshape(d7, [-1, 2048]) # 2048
158 | probs = tf_layers.fully_connected(feature, 1, scope='fc', normalizer_fn=None, activation_fn=tf.sigmoid) # 1
159 | probs = tf.reshape(probs, [-1])
160 |
161 | return probs
--------------------------------------------------------------------------------
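The resolution bookkeeping in the comments above is easy to verify: seven stride-2 convolutions take the 256 x 256 input down to the 2 x 2 bottleneck, hence the 2*2*512 = 2048-dimensional flattened feature. A quick check:

```python
# Feature-map side length after each stride-2 encoder conv.
size = 256
for name in ['e1', 'e2', 'e3', 'e4', 'e5', 'e6', 'e7']:
    size //= 2
    print(name, size)   # e1 128, e2 64, ..., e7 2
```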
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 |
25 | import image
26 | import reproject as rp
27 | import view as vw
28 |
29 | def compute_depth_loss(predicts, targets, mask, normalized=True):
30 | """
31 | input:
32 | predicts : n x H x W x 1 predicted depths
33 | targets : n x H x W x 1 ground-truth depths
34 | mask : n x H x W x 1 boolean mask
35 | normalized : boolean whether output loss should be normalized by pixel number
36 | output:
37 | loss : scalar loss value
38 | """
39 |
40 | num_batches = predicts.get_shape()[0].value
41 | num_channels = predicts.get_shape()[3].value
42 |
43 | diff = tf.abs(predicts-targets) # L-1 loss
44 | # diff = tf.square(predicts-targets) # L-2 loss
45 | diff = tf.boolean_mask(diff, tf.squeeze(mask, [3]))
46 | if normalized:
47 | depth_loss = tf.reduce_mean(diff) * (num_batches*num_channels)
48 | else:
49 | depth_loss = tf.reduce_sum(diff)
50 |
51 | return depth_loss
52 |
53 | def compute_normal_loss(predicts, targets, mask, normalized=True):
54 | """
55 | input:
56 | predicts : n x H x W x 3 predicted normals
57 | targets : n x H x W x 3 ground-truth normals
58 | mask : n x H x W x 1 boolean mask
59 | normalized : boolean whether output loss should be normalized by pixel number
60 | output:
61 | loss : scalar loss value
62 | """
63 |
64 | num_batches = predicts.get_shape()[0].value
65 | num_channels = predicts.get_shape()[3].value
66 |
67 | # with unit length 1-n_1*n_2 = 0.5*||n_1-n_2||^2
68 | diff = tf.square(predicts-targets)
69 | diff = tf.boolean_mask(diff, tf.squeeze(mask, [3]))
70 | if normalized:
71 | normal_loss = tf.reduce_mean(diff) * (num_batches*num_channels)
72 | else:
73 | normal_loss = tf.reduce_sum(diff)
74 |
75 | return normal_loss
76 |
77 | def compute_mask_loss(predicts, targets, normalized=True):
78 | """
79 | input:
80 | predicts : n x H x W x C generated masks (-1: false, 1: true)
81 | targets : n x H x W x C ground-truth masks (-1: false, 1: true)
82 | normalized : boolean whether output loss should be normalized by pixel number
83 | output:
84 | loss : scalar loss value
85 | """
86 |
87 | p = predicts * 0.5 + 0.5 # convert to probability
88 | z = targets * 0.5 + 0.5
89 | # L = -z*log(p)-(1-z)*log(1-p)
90 | mask_loss = tf.reduce_sum(-tf.multiply(tf.log(tf.maximum(1e-6, p)), z)-tf.multiply(tf.log(tf.maximum(1e-6, 1-p)), 1-z))
91 |
92 | if normalized:
93 | mask_shape = predicts.get_shape().as_list()
94 | num_pixels = np.prod(mask_shape[1:])
95 | mask_loss /= num_pixels
96 |
97 | return mask_loss
98 |
99 | def compute_pixel_loss(predicts, targets, normalized=True):
100 | """
101 | input:
102 | predicts : n x H x W x C predicted images
103 | targets : n x H x W x C ground-truth images
104 | normalized : boolean whether output loss should be normalized by pixel number
105 | output:
106 | loss : scalar loss value
107 | """
108 |
109 | num_batches = predicts.get_shape()[0].value
110 | num_channels = predicts.get_shape()[3].value
111 |
112 | diff = tf.abs(predicts-targets) # L-1 loss
113 | # diff = tf.square(predicts-targets) # L-2 loss
114 | if normalized:
115 | pixel_loss = tf.reduce_mean(diff) * (num_batches*num_channels)
116 | else:
117 | pixel_loss = tf.reduce_sum(diff)
118 |
119 | return pixel_loss
120 |
121 | def compute_consist_loss(contents, normalized=True):
122 | """
123 | input:
124 | contents : n x H x W x 4 normal/depth maps (nx, ny, nz, d)
125 | normalized : boolean whether output loss should be normalized by pixel number
126 | output:
127 | loss : scalar loss value
128 | """
129 |
130 | # Lx = | kappa * nx + dZdx * nz |
131 | # Ly = | kappa * ny + dZdy * nz |
132 |
133 | shape = contents.get_shape().as_list()
134 | num_batches = shape[0]
135 | H = shape[1]
136 | W = shape[2]
137 | kappaX = 5.0 / H # NOTE: view radius = 2.5
138 | kappaY = 5.0 / W
139 |
140 | filter_x = tf.convert_to_tensor(np.array([\
141 | [1.0, 0.0, -1.0],
142 | [4.0, 0.0, -4.0],
143 | [1.0, 0.0, -1.0]]), dtype=tf.float32)
144 | filter_y = tf.convert_to_tensor(np.array([\
145 | [-1.0, -4.0, -1.0],
146 | [0.0, 0.0, 0.0],
147 | [1.0, 4.0, 1.0]]), dtype=tf.float32)
148 | filter_x = tf.expand_dims(tf.expand_dims(filter_x, -1), -1)
149 | filter_y = tf.expand_dims(tf.expand_dims(filter_y, -1), -1)
150 |
151 | nx, ny, nz, d = tf.split(contents, 4, axis=3)
152 |
153 | dZdx = tf.nn.conv2d(d, filter=filter_x, strides=[1,1,1,1], padding='SAME')
154 | dZdy = tf.nn.conv2d(d, filter=filter_y, strides=[1,1,1,1], padding='SAME')
155 |
156 | Lx = tf.abs(tf.scalar_mul(kappaX, nx) + tf.multiply(dZdx, nz))
157 | Ly = tf.abs(tf.scalar_mul(kappaY, ny) + tf.multiply(dZdy, nz))
158 |
159 | if normalized:
160 | consist_loss = (tf.reduce_mean(Lx)+tf.reduce_mean(Ly)) * num_batches
161 | else:
162 | consist_loss = tf.reduce_sum(Lx)+tf.reduce_sum(Ly)
163 |
164 | return consist_loss
165 |
166 |
167 | def compute_corres_geom_loss(predicts, corres, views):
168 | """
169 | input:
170 | predicts : (n*v) x H x W x 4 predicted images
171 | corres : n x G x M x v correspondence point indices (G groups of M correspondences across v span views)
172 | views : vw.Views view points data
173 | output:
174 | loss : scalar loss value
175 | """
176 |
177 | if views.num_edges == 0:
178 | return 0
179 |
180 | position_factor = 1.0
181 | direction_factor = 1.0
182 |
183 | shape = predicts.get_shape().as_list()
184 | H = shape[1]
185 | W = shape[2]
186 | num_batches = shape[0] // views.num_views
187 | num_samples = corres.get_shape()[2].value
188 |
189 | points, dirs = rp.transform_tensor(predicts, views.views) # (n*V) x H x W x 3
190 |
191 | batch_points = tf.unstack(tf.reshape(points, [-1,views.num_views,H,W,3])) # [V x H x W x 3] * n
192 | batch_dirs = tf.unstack(tf.reshape(dirs, [-1,views.num_views,H,W,3])) # [V x H x W x 3] * n
193 | batch_corres = tf.unstack(corres) # [G x M x v] * n
194 |
195 | batch_losses = [None] * num_batches
196 | for batch_id in range(num_batches):
197 | all_points = tf.reshape(batch_points[batch_id], [-1,3]) # (V*H*W) x 3
198 | all_dirs = tf.reshape(batch_dirs[batch_id], [-1,3]) # (V*H*W) x 3
199 | all_corres = tf.reshape(batch_corres[batch_id], [-1]) # (G*M*v)
200 | slice_points = tf.reshape(tf.gather(all_points, all_corres), [views.num_edges,-1,views.edge_size,3]) # G x M x v x 3
201 | slice_dirs = tf.reshape(tf.gather(all_dirs, all_corres), [views.num_edges,num_samples,views.edge_size,3]) # G x M x v x 3
202 |
203 | # compute position loss as variance of reprojected point positions across nearby views
204 | normalized_points = slice_points - tf.tile(tf.reduce_mean(slice_points, reduction_indices=2, keep_dims=True), [1,1,views.edge_size,1]) # G x M x v x 3
205 | position_loss = tf.reduce_mean(tf.multiply(normalized_points, normalized_points))*3.0
206 |
207 | # compute direction loss as mean(1-dot(n,n)) for all pairs of reprojected directions across nearby views
208 | lensq_dirs = tf.maximum(tf.reduce_sum(tf.multiply(slice_dirs, slice_dirs), reduction_indices=3, keep_dims=True), 1e-3)
209 | normalized_dirs = tf.multiply(slice_dirs, tf.tile(tf.rsqrt(lensq_dirs), (1,1,1,3)))
210 | transposed = tf.reshape(tf.transpose(normalized_dirs, [2,0,1,3]), [views.edge_size, -1]) # V x (G*M*3)
211 | direction_loss = 1.0 - tf.reduce_mean(tf.matmul(transposed, transposed, transpose_b=True))*(1.0/(views.num_edges*num_samples))
212 |
213 | batch_losses[batch_id] = position_factor*position_loss + direction_factor*direction_loss
214 |
215 | loss = tf.reduce_sum(tf.stack(batch_losses))
216 |
217 | return loss
218 |
219 | def compute_corres_mask_loss(predicts, corres, views):
220 | """
221 | input:
222 | predicts : (n*v) x H x W x 1 predicted masks
223 | corres : n x G x M x v correspondence point indices (G groups of M correspondences across v span views)
224 | views : vw.Views view points data
225 | output:
226 | loss : scalar loss value
227 | """
228 |
229 | if views.num_edges == 0:
230 | return 0
231 |
232 | shape = predicts.get_shape().as_list()
233 | H = shape[1]
234 | W = shape[2]
235 | num_batches = shape[0] // views.num_views
236 | num_samples = corres.get_shape()[2].value
237 |
238 | probs = predicts*0.5+0.5 # [-1,1] => [0,1]
239 |
240 | batch_probs = tf.unstack(tf.reshape(probs, [-1,views.num_views,H,W,1])) # [V x H x W x 1] * n
241 | batch_corres = tf.unstack(corres) # [G x M x v] * n
242 |
243 | batch_losses = [None] * num_batches
244 | for batch_id in range(num_batches):
245 | all_probs = tf.reshape(batch_probs[batch_id], [-1,1]) # (V*H*W) x 1
246 | all_corres = tf.reshape(batch_corres[batch_id], [-1]) # (G*M*v)
247 | slice_probs = tf.reshape(tf.gather(all_probs, all_corres), [views.num_edges,-1,views.edge_size,1]) # G x M x v x 1
248 |
249 | # compute mask loss as Jensen-Shannon divergence of predicted mask probabilities across nearby views
250 | mask_loss = tf.reduce_mean( compute_entropy(tf.reduce_mean(slice_probs, reduction_indices=1)) - tf.reduce_mean(compute_entropy(slice_probs), reduction_indices=1) )
251 |
252 | batch_losses[batch_id] = mask_loss
253 |
254 | loss = tf.reduce_sum(tf.stack(batch_losses))
255 |
256 | return loss
257 |
258 | def compute_entropy(tensor):
259 | """
260 | input:
261 | tensor : any shape tensor
262 | output:
263 | entropy : tensor having the same shape with input tensor
264 | """
265 |
266 | entropy = - tf.multiply(tensor, tf.log(tensor+1e-6)) - tf.multiply(1.0-tensor, tf.log(1.0-tensor+1e-6))
267 | return entropy
--------------------------------------------------------------------------------
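The depth/normal consistency term in `compute_consist_loss` encodes that for a surface z(x, y) with unit normal n, the tangent relation gives kappa * n_x + dZ/dx * n_z = 0 (and likewise for y), where kappa converts pixel steps into world units. A numeric check on a planar depth map (illustrative, with kappa = 1 in world units):

```python
# For z = a*x the unit normal is (-a, 0, 1)/sqrt(1+a^2), so the
# consistency residual kappa*nx + dZdx*nz vanishes exactly.
import numpy as np

a = 0.3
n = np.array([-a, 0.0, 1.0]) / np.sqrt(1.0 + a * a)
dZdx = a
print(1.0 * n[0] + dZdx * n[2])   # ~0.0
```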
/data.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 |
25 | import os
26 | import math
27 |
28 | import image
29 | import view as vw
30 | from pathlib import Path
31 |
32 | NUM_CORRESPONDENCES = 1024
33 |
34 | def load_data(config, views, shape_list, shuffle=True, batch_size=-1):
35 | """
36 | input:
37 | config tf.app.flags command line arguments
38 | views vw.View view points information
39 | shape_list list of string input shape name list
40 | shuffle bool whether input shape list should be shuffled
41 | output:
42 | name_batch n x string shape names
43 | source_batch n x H x W x Ci source images
44 | target_batch (n*m) x H x W x Co target images in m views
45 | mask_batch (n*m) x H x W x 1 target boolean masks in m views
46 | angle_batch (n*m) x 4 target viewing angle params in m views
47 | num_shapes int number of loaded shapes
48 | """
49 |
50 | if batch_size==-1:
51 | batch_size = config.batch_size
52 |
53 | # handle affix
54 |
55 | num_source_views = len(config.sketch_views)
56 | # source_prefix_list = ['sketch/' for view in range(num_source_views)]
57 | # source_interfix_list = ['/sketch-%c' % v for v in config.sketch_views]
58 | # if config.test:
59 | # sketch_variation = '0'
60 | # else:
61 | # sketch_variation_queue = tf.train.string_input_producer(['%d' % v for v in range(config.sketch_variations)], shuffle=True)
62 | # sketch_variation = sketch_variation_queue.dequeue()
63 | # source_suffix_list = ['-'+sketch_variation+'.png' for view in range(num_source_views)]
64 |
65 | num_dnfs_views = max(2, len(config.sketch_views))
66 | dnfs_prefix_list = ['dnfs/' for view in range(num_dnfs_views)]
67 | dnfs_interfix_list = ['/dnfs-%d' % config.image_size for view in range(num_dnfs_views)]
68 | dnfs_suffix_list = ['-%d.png' % view for view in range(num_dnfs_views)]
69 |
70 | num_dn_views = 12
71 | dn_prefix_list = ['dn/' for view in range(num_dn_views)]
72 | dn_interfix_list = ['/dn-%d' % config.image_size for view in range(num_dn_views)]
73 | dn_suffix_list = ['-%d.png' % view for view in range(num_dn_views)]
74 |
75 | num_target_views = num_dnfs_views + num_dn_views
76 | target_prefix_list = dnfs_prefix_list + dn_prefix_list
77 | target_interfix_list = dnfs_interfix_list + dn_interfix_list
78 | target_suffix_list = dnfs_suffix_list + dn_suffix_list
79 | num_target_views = views.num_views
80 |
81 | # build input queue
82 |
83 | if config.continuous_view and config.test:
84 | shape_list_queue = tf.train.input_producer([name for name in shape_list for view in range(num_target_views)], shuffle=False)
85 | else:
86 | shape_list_queue = tf.train.input_producer(shape_list, shuffle=shuffle)
87 |
88 | # load data from queue
89 |
90 | shape_name = shape_list_queue.dequeue()
91 | extension = 'jpg'
92 | # import pudb; pu.db
93 | image_dir = config.sketch_dir+shape_name+config.sketch_set
94 | if 'human' in config.sketch_set:
95 | file_glob = image_dir + '/' + '*.' + extension
96 | source_files_list = tf.matching_files(file_glob)
97 | else:
98 | file_glob_base = image_dir + '/base/' + '*.' + extension
99 | file_glob_bias = image_dir + '/bias/' + '*.' + extension
100 | # import pudb; pu.db
101 | source_files_list_base = tf.matching_files(file_glob_base)
102 | source_files_list_bias = tf.matching_files(file_glob_bias)
103 | # print('############################################')
104 | # print(source_files_list_base)
105 | source_files_list = tf.concat([source_files_list_base,source_files_list_bias], 0)
106 |
107 | source_file_queue = tf.train.string_input_producer(source_files_list, shuffle=True)
108 | source_file = source_file_queue.dequeue()
109 | # source_files = [config.data_dir+shape_name+source_prefix_list_base[view]+source_interfix_list[view]+source_suffix_list[view] for view in range(num_source_views)]
110 | # source_files = [config.data_dir+source_prefix_list[view]+shape_name+source_interfix_list[view]+source_suffix_list[view] for view in range(num_source_views)]
111 | if not config.continuous_view:
112 | target_files = [config.data_dir+target_prefix_list[view]+shape_name+target_interfix_list[view]+target_suffix_list[view] for view in range(num_target_views)]
113 | target_angles = tf.zeros([num_target_views, 4])
114 | else:
115 | angle_list = [vw.view2angle(view) for view in views.views]
116 | view_list_queue = tf.train.slice_input_producer([angle_list, target_prefix_list, target_interfix_list, target_suffix_list], shuffle=(not config.test))
117 | target_files = [config.data_dir+view_list_queue[1]+shape_name+view_list_queue[2]+view_list_queue[3]] # only one single image
118 | target_angles = [view_list_queue[0]]
119 |
120 | # decode source images
121 | # source_images = [tf.image.decode_png(tf.read_file(file), channels=1, dtype=tf.uint8) for file in source_files]
122 | # source_image = tf.concat(source_images, 2) # put multi-view images into different channels
123 | source_images = [tf.image.resize_images(tf.image.decode_png(tf.read_file(source_file), channels=1, dtype=tf.uint8), [256, 256], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)]
124 | source_image = tf.concat(source_images, 2) # put multi-view images into different channels
125 | source_image = image.normalize_image(tf.slice(source_image, [0,0,0], [config.image_size, config.image_size, -1])) # just do a useless slicing to establish size
126 | source_image = tf.concat([source_image, tf.image.flip_left_right(source_image)], 2) # HACK: add horizontally flipped image as input
127 |
128 | # decode target images
129 |
130 | if not config.test:
131 | target_images = tf.stack([tf.image.decode_png(tf.read_file(file), channels=4, dtype=tf.uint16) for file in target_files])
132 | target_images = image.normalize_image(tf.slice(target_images, [0,0,0,0], [-1,config.image_size, config.image_size, -1]))
133 | else:
134 | target_images = tf.ones([len(target_files), config.image_size, config.image_size, 4]) # dummy target for testing
135 | target_masks = image.extract_boolean_mask(target_images)
136 |
137 | if config.predict_normal:
138 | # pre-process normal background
139 | target_shape = target_images.get_shape().as_list()
140 | target_background = tf.concat([tf.zeros(target_shape[:-1]+[2]), tf.ones(target_shape[:-1]+[2])], 3) # (0,0,1,1)
141 | target_images = tf.where(tf.tile(target_masks, [1,1,1,target_shape[3]]), target_images, target_background)
142 | else:
143 | # retain depth only
144 | target_images = tf.slice(target_images, [0,0,0,3], [-1,-1,-1,1])
145 |
146 | target_images = tf.concat([target_images, image.convert_to_real_mask(target_masks)], 3)
147 |
148 | # create prefetching tensor
149 |
150 | num_shapes = len(shape_list)
151 | min_queue_examples = max(1, int(num_shapes * 0.01))
152 |
153 | tensor_data = [shape_name, source_image, target_images, target_masks, target_angles]
154 | print('name: ', shape_name)
155 | print('source: ', source_image)
156 | print('target: ', target_images)
157 | print('mask: ', target_masks)
158 | print('angle: ', target_angles)
159 |
160 | if shuffle:
161 | num_preprocess_threads = 12
162 | batch_data = tf.train.shuffle_batch(
163 | tensor_data,
164 | batch_size=batch_size,
165 | num_threads=num_preprocess_threads,
166 | capacity=min_queue_examples + 3 * batch_size,
167 | min_after_dequeue=min_queue_examples)
168 | else:
169 | num_preprocess_threads = 1
170 | batch_data = tf.train.batch(
171 | tensor_data,
172 | batch_size=batch_size,
173 | num_threads=num_preprocess_threads,
174 | capacity=min_queue_examples)
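
    # queue sizing note for the shuffle branch above: min_after_dequeue
    # (1% of the shape list, at least 1) is the floor of buffered examples
    # that keeps shuffled batches well mixed, and
    # capacity = min_after_dequeue + 3*batch_size bounds the queue's memory
    # use, which matches the sizing rule suggested by the TF1
    # tf.train.shuffle_batch documentation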
175 |
176 | name_batch = batch_data[0]
177 | source_batch = batch_data[1]
178 | target_batch = batch_data[2]
179 | target_batch = tf.reshape(target_batch, [-1]+target_batch.get_shape().as_list()[2:])
180 | mask_batch = batch_data[3]
181 | mask_batch = tf.reshape(mask_batch, [-1]+mask_batch.get_shape().as_list()[2:])
182 | angle_batch = batch_data[4]
183 | angle_batch = tf.reshape(angle_batch, [-1]+angle_batch.get_shape().as_list()[2:])
184 |
185 | print('*******************************')
186 | print('name: ', name_batch)
187 | print('source: ', source_batch)
188 | print('target: ', target_batch)
189 | print('mask: ', mask_batch)
190 | print('angle: ', angle_batch)
191 |
192 | return name_batch, source_batch, target_batch, mask_batch, angle_batch, num_shapes
193 |
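# Illustrative wiring (a sketch only; main.py is the actual entry point, and
# the config/views objects here are assumed to come from there): the batched
# tensors returned by load_data plug directly into MonNet.build_network in
# monnet.py.
#
#   names, sources, targets, masks, angles, n = load_train_data(config, views)
#   net = monnet.MonNet(config)
#   net.build_network(names, sources, targets, masks, angles, views,
#                     is_training=True)
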
194 | def load_train_data(config, views, batch_size=-1):
195 |
196 | print("Loading training data...")
197 |
198 | shape_list_file = open(os.path.join(config.data_dir, 'train-list.txt'), 'r')
199 | shape_list = shape_list_file.read().splitlines()
200 | shape_list_file.close()
201 |
202 | return load_data(config, views, shape_list, shuffle=True, batch_size=batch_size)
203 |
214 | def load_test_data(config, views, batch_size=-1):
215 |
216 | print("Loading testing data...")
217 |
218 | shape_list_file = open(os.path.join(config.data_dir, 'test-list.txt'), 'r')
219 | shape_list = shape_list_file.read().splitlines()
220 | shape_list_file.close()
221 |     # resume support: skip shapes whose results already exist on disk
222 |     test_path = Path(config.check_dir)/'results'/'03001627' # ShapeNet chair synset ID
223 | exists_list = [x for x in test_path.iterdir() if x.is_dir()]
224 | extra_path = Path(config.test_dir)/'results'/'03001627'
225 | if extra_path.exists():
226 | extra_list = [x for x in extra_path.iterdir() if x.is_dir()]
227 | exists_list.extend(extra_list)
228 | for item in exists_list:
229 | shape = item.parts[-2] + '/' + item.name
230 | if shape in shape_list:
231 | shape_list.remove(shape)
232 |
233 | return load_data(config, views, shape_list, shuffle=False, batch_size=batch_size)
234 |
235 | def load_encode_data(config, views, batch_size=-1):
236 |
237 | print("Loading encoding data...")
238 |
239 | shape_list_file = open(os.path.join(config.data_dir, 'list.txt'), 'r')
240 | shape_list = shape_list_file.read().splitlines()
241 | shape_list_file.close()
242 |
243 | return load_data(config, views, shape_list, shuffle=False, batch_size=batch_size)
244 |
245 | def load_validate_data(config, views, batch_size=-1):
246 |
247 | print("Loading validation data...")
248 |
249 | shape_list_file = open(os.path.join(config.data_dir, 'validate-list.txt'), 'r')
250 | shape_list = shape_list_file.read().splitlines()
251 | shape_list_file.close()
252 |
253 | return load_data(config, views, shape_list, shuffle=False, batch_size=batch_size)
254 |
255 | def write_bin_data(file_name, data):
256 |
257 | path = os.path.dirname(file_name)
258 | if not os.path.exists(path):
259 | os.makedirs(path)
260 | data.tofile(file_name)
261 |
262 | def write_pfm_data(file_name, data):
263 |
264 | path = os.path.dirname(file_name)
265 | if not os.path.exists(path):
266 | os.makedirs(path)
267 | file = open(file_name, 'wb')
268 |
269 |     if data.shape[2] == 1:
270 |         file.write(b'Pf\n') # single-channel (grayscale) PFM
271 |     elif data.shape[2] == 3:
272 |         file.write(b'PF\n') # three-channel (color) PFM
273 |     else:
274 |         raise ValueError('incorrect number of channels')
275 |
276 |     file.write(('%d %d\n' % (data.shape[1], data.shape[0])).encode('ascii'))
277 |     file.write(b'-1.0\n') # negative scale marks little-endian float data
278 |
279 | data = np.flipud(data) # PFM format stores pixels from bottom to top...
280 | data.tofile(file)
281 |
282 | file.close()
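
# For reference, a minimal reader that inverts write_pfm_data above. This is
# an illustrative sketch, not part of the original project: it parses the
# 'Pf'/'PF' type line, the width/height line, and the scale line (whose sign
# encodes endianness), then undoes the bottom-to-top row order.
def read_pfm_data(file_name):

    with open(file_name, 'rb') as file:
        header = file.readline().decode('ascii').strip()
        if header == 'Pf':
            channels = 1
        elif header == 'PF':
            channels = 3
        else:
            raise ValueError('not a PFM file')
        width, height = (int(v) for v in file.readline().split())
        scale = float(file.readline())
        dtype = '<f4' if scale < 0 else '>f4' # negative scale => little-endian
        data = np.fromfile(file, dtype=dtype, count=width*height*channels)
    return np.flipud(data.reshape(height, width, channels))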
--------------------------------------------------------------------------------
/monnet.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is part of the Sketch Modeling project.
3 |
4 | Copyright (c) 2017
5 | -Zhaoliang Lun (author of the code) / UMass-Amherst
6 |
7 | This is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This software is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this software. If not, see <http://www.gnu.org/licenses/>.
19 | """
20 |
21 |
22 | import tensorflow as tf
23 | import numpy as np
24 |
25 | import tensorflow.contrib.framework as tf_framework
26 |
27 | import time
28 | import os
29 | import math
30 |
31 | import data
32 | import image
33 | import network
34 | import layer
35 | import loss
36 | import reproject as rp
37 | import view as vw
38 |
39 | class MonNet(object):
40 |
41 | def __init__(self, config):
42 | self.config = config
43 |
44 | def build_network(self, names, sources, targets, masks, angles, views, is_training=False, is_validation=False, is_testing=False, is_encoding=False):
45 | """
46 | input:
47 | names : n x String shape names
48 | sources : n x H x W x C source images
49 | targets : (n*m) x H x W x C target images in m views (ground-truth)
50 | masks : (n*m) x H x W x 1 target boolean masks in m views (ground-truth)
51 | angles : (n*m) x 4 viewing angle parameters (m=1 for continuous view prediction)
52 | views : vw.Views view points information
53 | is_training : boolean whether it is in training routine
54 | is_validation : boolean whether it is handling validation data set
55 | is_testing : boolean whether it is in testing routine
56 | is_encoding : boolean whether it is encoding input
57 | """
58 |
59 | print('Building network...')
60 |
61 | source_size = sources.get_shape().as_list()
62 | if self.config.continuous_view:
63 | num_output_views = 1
64 | else:
65 | num_output_views = views.num_views
66 |
67 | # scope names
68 |
69 | var_scope_G = 'G_net'
70 | var_scope_D = 'D_net'
71 | bn_scope_G = 'G_bn'
72 | bn_scope_D = 'D_bn'
73 | train_summary_G_name = 'train_summary_G'
74 | train_summary_D_name = 'train_summary_D'
75 | valid_summary_name = 'valid_summary'
76 |
77 | # generator
78 |
79 | num_channels = targets.get_shape()[3].value
80 | if not self.config.continuous_view:
81 | with tf.variable_scope(var_scope_G):
82 | with tf_framework.arg_scope(layer.unet_scopes(bn_scope_G)):
83 | preds, features = network.generateUNet(sources, num_output_views, num_channels) # (n*m) x H x W x C ; n x D
84 | else:
85 | with tf.variable_scope(var_scope_G):
86 | with tf_framework.arg_scope(layer.cnet_scopes(bn_scope_G)):
87 | preds, features = network.generateCNet(sources, angles, num_channels) # n x H x W x C ; n x D
88 |
89 | if is_encoding:
90 | self.encode_names = names
91 | self.encode_features = features
92 |             return # everything below is irrelevant to the encoding pass
93 |
94 | # extract prediction contents
95 |
96 | preds_content = tf.slice(preds, [0,0,0,0], [-1,-1,-1,num_channels-1])
97 | preds_mask = tf.slice(preds, [0,0,0,num_channels-1], [-1,-1,-1,1])
98 | preds = image.apply_mask(preds_content, preds_mask)
99 | targets_content = tf.slice(targets, [0,0,0,0], [-1,-1,-1,num_channels-1])
100 | targets_mask = tf.slice(targets, [0,0,0,num_channels-1], [-1,-1,-1,1])
101 | targets = image.apply_mask(targets_content, targets_mask)
102 | if self.config.predict_normal:
103 | preds_normal = tf.slice(preds_content, [0,0,0,0], [-1,-1,-1,3])
104 | preds_depth = tf.slice(preds_content, [0,0,0,3], [-1,-1,-1,1])
105 | targets_normal = tf.slice(targets_content, [0,0,0,0], [-1,-1,-1,3])
106 | targets_depth = tf.slice(targets_content, [0,0,0,3], [-1,-1,-1,1])
107 | else:
108 | preds_depth = preds_content
109 | preds_normal = tf.tile(tf.zeros_like(preds_depth), [1,1,1,3])
110 | targets_depth = targets_content
111 | targets_normal = tf.tile(tf.zeros_like(targets_depth), [1,1,1,3])
112 |
113 | # expand tensors
114 |
115 | sources_expanded = tf.reshape(tf.tile(sources, [1,num_output_views,1,1]),[-1,source_size[1],source_size[2],source_size[3]]) # (n*m) x H x W x C
116 |
117 | names_expanded = tf.reshape(tf.tile(tf.expand_dims(names,1),[1,num_output_views]),[-1])
118 | names_suffix = ["--%d" % view for batch in range(source_size[0]) for view in range(num_output_views)]
119 | names_expanded = tf.reduce_join([names_expanded, names_suffix], 0)
120 | self.names = names_expanded
121 |
122 | # discriminator
123 |
124 | if not self.config.no_adversarial:
125 | with tf.variable_scope(var_scope_D):
126 | with tf_framework.arg_scope(layer.unet_scopes(bn_scope_D)):
127 | disc_data = tf.concat([targets, preds], 0)
128 | disc_data = tf.concat([tf.concat([sources_expanded, sources_expanded], 0), disc_data], 3) # HACK: insert input data for discrimination in UNet
129 | probs = network.discriminate(disc_data) # (n*m*2)
130 |
131 | # losses
132 |
133 | # NOTE: learning hyper-parameters
134 | lambda_p = 1.0 # image loss
135 | lambda_a = 0.01 # adversarial loss
136 |
137 | dl = loss.compute_depth_loss(preds_depth, targets_depth, masks)
138 | nl = loss.compute_normal_loss(preds_normal, targets_normal, masks)
139 | ml = loss.compute_mask_loss(preds_mask, targets_mask)
140 | loss_g_p = dl + nl + ml
141 |
142 | if self.config.no_adversarial:
143 | loss_g_a = 0.0
144 | loss_d_r = 0.0
145 | loss_d_f = 0.0
146 | else:
147 | probs_targets, probs_preds = tf.split(probs, 2, axis=0) # (n*m)
148 | loss_g_a = tf.reduce_sum(-tf.log(tf.maximum(probs_preds, 1e-6)))
149 | loss_d_r = tf.reduce_sum(-tf.log(tf.maximum(probs_targets, 1e-6)))
150 | loss_d_f = tf.reduce_sum(-tf.log(tf.maximum(1.0-probs_preds, 1e-6)))
151 |
152 | loss_G = loss_g_p * lambda_p + loss_g_a * lambda_a
153 | loss_D = loss_d_r + loss_d_f
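
        # Note on the adversarial terms above: this is the standard
        # non-saturating GAN objective, summed over the batch. G minimizes
        # -log(D(fake)) rather than maximizing log(1 - D(fake)), which keeps
        # gradients informative when D confidently rejects early fakes; the
        # 1e-6 floor caps each term at -log(1e-6) ~= 13.8, so a saturated
        # discriminator cannot produce inf/NaN losses.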
154 |
155 | if is_validation:
156 | self.valid_losses = tf.stack([loss_G, loss_g_p, loss_g_a, loss_D, loss_d_r, loss_d_f])
157 | self.valid_images = tf.stack([
158 | image.encode_raw_batch_images(preds),
159 | image.encode_raw_batch_images(targets),
160 | image.encode_raw_batch_images(preds_normal),
161 | image.encode_raw_batch_images(preds_depth),
162 | image.encode_raw_batch_images(preds_mask)])
163 | self.valid_summary_losses = tf.placeholder(tf.float32, shape=self.valid_losses.get_shape())
164 | vG_all, vG_p, vG_a, vD_all, vD_r, vD_f = tf.unstack(self.valid_summary_losses)
165 | tf.summary.scalar('vG_all', vG_all, collections=[valid_summary_name])
166 | tf.summary.scalar('vG_p', vG_p, collections=[valid_summary_name])
167 | tf.summary.scalar('vG_a', vG_a, collections=[valid_summary_name])
168 | tf.summary.scalar('vD_all', vD_all, collections=[valid_summary_name])
169 | tf.summary.scalar('vD_r', vD_r, collections=[valid_summary_name])
170 | tf.summary.scalar('vD_f', vD_f, collections=[valid_summary_name])
171 | self.valid_summary_op = tf.summary.merge_all(valid_summary_name)
172 |             return # everything below is irrelevant to the validation pass
173 |
174 | self.train_losses_G = tf.stack([loss_G, loss_g_p, loss_g_a])
175 | self.train_losses_D = tf.stack([loss_D, loss_d_r, loss_d_f])
176 | tf.summary.scalar('G_all', loss_G, collections=[train_summary_G_name])
177 | tf.summary.scalar('G_p', loss_g_p, collections=[train_summary_G_name])
178 | tf.summary.scalar('G_a', loss_g_a, collections=[train_summary_G_name])
179 | tf.summary.scalar('D_all', loss_D, collections=[train_summary_D_name])
180 | tf.summary.scalar('D_r', loss_d_r, collections=[train_summary_D_name])
181 | tf.summary.scalar('D_f', loss_d_f, collections=[train_summary_D_name])
182 |
183 | # statistics on variables
184 |
185 | all_vars = tf.trainable_variables()
186 | all_vars_G = [var for var in all_vars if var_scope_G in var.name]
187 | all_vars_D = [var for var in all_vars if var_scope_D in var.name]
188 | #print('Num all vars: %d' % len(all_vars))
189 | #print('Num vars on G net: %d' % len(all_vars_G))
190 | #print('Num vars on D net: %d' % len(all_vars_D))
191 | num_params_G = 0
192 | num_params_D = 0
193 | # print('G vars:')
194 | for var in all_vars_G:
195 | num_params_G += np.prod(var.get_shape().as_list())
196 | # print(var.name, var.get_shape().as_list())
197 | # print('D vars:')
198 | for var in all_vars_D:
199 | num_params_D += np.prod(var.get_shape().as_list())
200 | # print(var.name, var.get_shape().as_list())
201 | #print('Num all params: %d + %d = %d' % (num_params_G, num_params_D, num_params_G+num_params_D))
202 | #input('pause')
203 |
204 | # optimization
205 |
206 | # NOTE: learning hyper-parameters
207 | init_learning_rate = 0.0001
208 | adam_beta1 = 0.9
209 | adam_beta2 = 0.999
210 | opt_step = tf.Variable(0, trainable=False)
211 | learning_rate = tf.train.exponential_decay(init_learning_rate, global_step=opt_step, decay_steps=10000, decay_rate=0.96, staircase=True)
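        # with staircase=True the rate is piecewise constant in the optimizer
        # step (opt_step, incremented once per G update):
        #   lr = 1e-4 * 0.96 ** (opt_step // 10000)
        # i.e. 1e-4 for updates 0-9999, 9.6e-5 for 10000-19999, and so on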
212 |
213 | opt_G = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=adam_beta1, beta2=adam_beta2, name='ADAM_G')
214 | opt_D = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=adam_beta1, beta2=adam_beta2, name='ADAM_D')
215 | # opt_G = tf.train.GradientDescentOptimizer(learning_rate=learning_rate, name='SGD_G')
216 | # opt_D = tf.train.GradientDescentOptimizer(learning_rate=learning_rate, name='SGD_D')
217 |
218 | grad_G = opt_G.compute_gradients(loss_G, var_list=all_vars_G, colocate_gradients_with_ops=True)
219 | self.grad_G_placeholder = [(tf.placeholder(tf.float32, shape=grad[1].get_shape()), grad[1]) for grad in grad_G if grad[0] is not None]
220 | self.grad_G_list = [grad[0] for grad in grad_G if grad[0] is not None]
221 | self.update_G_op = opt_G.apply_gradients(self.grad_G_placeholder, global_step=opt_step) # only update opt_step in G net
222 |
223 | if not self.config.no_adversarial:
224 | grad_D = opt_D.compute_gradients(loss_D, var_list=all_vars_D, colocate_gradients_with_ops=True)
225 | self.grad_D_placeholder = [(tf.placeholder(tf.float32, shape=grad[1].get_shape()), grad[1]) for grad in grad_D if grad[0] is not None]
226 | self.grad_D_list = [grad[0] for grad in grad_D if grad[0] is not None]
227 | self.update_D_op = opt_D.apply_gradients(self.grad_D_placeholder)
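
        # Gradient-accumulation scheme: compute_gradients runs once per step
        # and the raw gradients are summed on the host (cumulate_gradients in
        # train()); every update_interval steps the averaged gradients are fed
        # back through the placeholders above and applied in a single
        # apply_gradients call. This emulates a batch update_interval times
        # larger than what fits in memory, at the cost of extra host/device
        # transfers.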
228 |
229 | # visualization stuffs
230 |
231 | sources_original, sources_flipped = tf.split(sources_expanded, 2, axis=3)
232 | if len(self.config.sketch_views) == 1: # single input
233 | sources_front = sources_original
234 | sources_side = tf.ones_like(sources_front) # fake side sketch
235 | sources_top = tf.ones_like(sources_front) # fake top sketch
236 | elif len(self.config.sketch_views) == 2: # double input
237 | sources_front, sources_side = tf.split(sources_original, 2, axis=3)
238 | sources_top = tf.ones_like(sources_front) # fake top sketch
239 | elif len(self.config.sketch_views) == 3: # triple input
240 | sources_front, sources_side, sources_top = tf.split(sources_original, 3, axis=3)
241 | if sources_front.get_shape()[3].value == 1 and targets.get_shape()[3].value == 4:
242 | alpha_front = tf.ones_like(sources_front)
243 | alpha_side = tf.ones_like(sources_side)
244 | alpha_top = tf.ones_like(sources_top)
245 | rgb_front = image.convert_to_rgb(sources_front, channels=3)
246 | rgb_side = image.convert_to_rgb(sources_side, channels=3)
247 | rgb_top = image.convert_to_rgb(sources_top, channels=3)
248 | sources_front = tf.concat([rgb_front, alpha_front], 3)
249 | sources_side = tf.concat([rgb_side, alpha_side], 3)
250 | sources_top = tf.concat([rgb_top, alpha_top], 3)
251 |
252 | input_row = tf.concat([sources_front, sources_side], 2)
253 | output_row = tf.concat([targets, preds], 2)
254 |
255 | result_tile = tf.concat([input_row, output_row], 1)
256 | result_tile = image.saturate_image(image.unnormalize_image(result_tile))
257 |
258 | tf.summary.image('result', result_tile, 12, [train_summary_G_name])
259 |
260 | self.train_summary_G_op = tf.summary.merge_all(train_summary_G_name)
261 | self.train_summary_D_op = tf.summary.merge_all(train_summary_D_name)
262 |
263 | # output images
264 |
265 | num_sketch_views = len(self.config.sketch_views)
266 | if num_sketch_views==1:
267 | all_input_row = sources_front
268 | elif num_sketch_views==2:
269 | all_input_row = tf.concat([sources_front, sources_side], 2)
270 | elif num_sketch_views==3:
271 | all_input_row = tf.concat([sources_front, sources_side, sources_top], 2)
272 | img_input = image.saturate_image(image.unnormalize_image(all_input_row, maxval=65535.0), dtype=tf.uint16)
273 | img_gt = image.saturate_image(image.unnormalize_image(targets, maxval=65535.0), dtype=tf.uint16)
274 | img_output = image.saturate_image(image.unnormalize_image(preds, maxval=65535.0), dtype=tf.uint16)
275 | png_input = image.encode_batch_images(img_input)
276 | png_gt = image.encode_batch_images(img_gt)
277 | png_output = image.encode_batch_images(img_output)
278 |
279 | img_normal = image.saturate_image(image.unnormalize_image(preds_normal, maxval=65535.0), dtype=tf.uint16)
280 | img_depth = image.saturate_image(image.unnormalize_image(preds_depth, maxval=65535.0), dtype=tf.uint16)
281 | img_mask = image.saturate_image(image.unnormalize_image(preds_mask, maxval=65535.0), dtype=tf.uint16)
282 | png_normal = image.encode_batch_images(img_normal)
283 | png_depth = image.encode_batch_images(img_depth)
284 | png_mask = image.encode_batch_images(img_mask)
285 | self.pngs = tf.stack([png_input, png_gt, png_output, png_normal, png_depth, png_mask])
286 |
287 | # output results
288 |
289 | pixel_shape = preds.get_shape().as_list()
290 | num_pixels = np.prod(pixel_shape[1:])
291 | self.errors = tf.reduce_sum(tf.abs(preds-targets), [1,2,3]) / num_pixels # just a quick check
292 | self.results = preds
293 |
294 | # batch normalization
295 |
296 | bn_G_collection = tf.get_collection(bn_scope_G)
297 | bn_D_collection = tf.get_collection(bn_scope_D)
298 | self.bn_G_op = tf.group(*bn_G_collection)
299 | self.bn_D_op = tf.group(*bn_D_collection)
300 |
301 | def train(self, sess, views, num_train_shapes, num_valid_shapes):
302 |
303 | print('Training...')
304 |
305 | ckpt = tf.train.get_checkpoint_state(self.config.train_dir)
306 | init_op = tf.global_variables_initializer()
307 | sess.run(init_op)
308 | if ckpt and ckpt.model_checkpoint_path:
309 | self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=10.0, max_to_keep=2)
310 | self.saver.restore(sess, ckpt.model_checkpoint_path)
311 | try:
312 | self.step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
313 | except ValueError:
314 | self.step = 0
315 | else:
316 | self.saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=10.0, max_to_keep=2)
317 | self.step = 0
318 |
319 | coord = tf.train.Coordinator()
320 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
321 | self.summarizer = tf.summary.FileWriter(self.config.train_dir, sess.graph)
322 |
323 | print_interval = 40 // self.config.batch_size # steps
324 | update_interval = 40 // self.config.batch_size # steps
325 | summary_interval = 200 # steps
326 | validate_interval = 200 # steps
327 | output_interval = 1000 # steps
328 | checkpoint_interval = 1000 # steps
329 |
330 | print('Start iterating...')
331 |
332 | start_time = time.time()
333 |
334 | train_D_net = not self.config.no_adversarial
335 | batch_grad_G_list = None
336 | batch_grad_D_list = None
337 | batch_losses_G = None
338 | batch_losses_D = None
339 | step_losses_G = None
340 | step_losses_D = None
341 |
342 | while True:
343 |
344 | # compute epochs
345 |
346 | epochs = 1.0*(self.step+1)*self.config.batch_size/num_train_shapes
347 | do_print = ((self.step+1) % print_interval == 0)
348 | do_update = ((self.step+1) % update_interval == 0)
349 | do_validate = ((self.step+1) % validate_interval == 0)
350 | do_summary = ((self.step+1) % summary_interval == 0)
351 | do_checkpoint = ((self.step+1) % checkpoint_interval == 0)
352 | do_output = ((self.step+1) % output_interval == 0)
353 |
354 | # training networks
355 |
356 | step_G_list = sess.run(self.grad_G_list + [self.bn_G_op, self.train_losses_G])
357 | step_grad_G_list = step_G_list[:-2]
358 | step_losses_G = step_G_list[-1] / self.config.batch_size
359 | batch_grad_G_list = self.cumulate_gradients(batch_grad_G_list, step_grad_G_list)
360 |
361 | if train_D_net:
362 | step_D_list = sess.run(self.grad_D_list + [self.bn_D_op, self.train_losses_D])
363 | step_grad_D_list = step_D_list[:-2]
364 | step_losses_D = step_D_list[-1] / self.config.batch_size
365 | batch_grad_D_list = self.cumulate_gradients(batch_grad_D_list, step_grad_D_list)
366 | else:
367 | if step_losses_D is None:
368 | step_losses_D = [0.0, 0.0, 0.0]
369 |
370 | batch_losses_G = step_losses_G if batch_losses_G is None else batch_losses_G+step_losses_G
371 | batch_losses_D = step_losses_D if batch_losses_D is None else batch_losses_D+step_losses_D
372 |
373 | # update gradients
374 |
375 | if do_update:
376 | grad_G_dict = {}
377 | for k in range(len(self.grad_G_placeholder)):
378 | grad_G_dict[self.grad_G_placeholder[k][0]] = batch_grad_G_list[k] / update_interval
379 | sess.run(self.update_G_op, feed_dict=grad_G_dict)
380 | batch_grad_G_list = None
381 |
382 | if train_D_net:
383 | grad_D_dict = {}
384 | for k in range(len(self.grad_D_placeholder)):
385 | grad_D_dict[self.grad_D_placeholder[k][0]] = batch_grad_D_list[k] / update_interval
386 | sess.run(self.update_D_op, feed_dict=grad_D_dict)
387 | batch_grad_D_list = None
388 |
389 | if not self.config.no_adversarial:
390 | batch_losses_G = batch_losses_G / update_interval
391 | if batch_losses_D is not None:
392 | batch_losses_D = batch_losses_D / update_interval
393 |                         train_D_net = (batch_losses_D[0] > batch_losses_G[2] * 0.1) # train D only while its total loss exceeds 10% of G's adversarial loss
394 | batch_losses_G = None
395 | batch_losses_D = None
396 |
397 | # validation
398 |
399 | if do_validate:
400 | self.validate_loss(sess, num_valid_shapes)
401 |
402 | if do_output:
403 | self.validate_output(sess, num_valid_shapes, epochs)
404 |
405 | # log
406 |
407 | if do_summary:
408 | summary_G_str = sess.run(self.train_summary_G_op)
409 | self.summarizer.add_summary(summary_G_str, self.step)
410 | if train_D_net:
411 | summary_D_str = sess.run(self.train_summary_D_op)
412 | self.summarizer.add_summary(summary_D_str, self.step)
413 |
414 | if do_checkpoint:
415 | self.saver.save(sess, os.path.join(self.config.train_dir,'model.ckpt'), global_step=self.step+1)
416 |
417 | if do_print:
418 | now_time = time.time()
419 | batch_duration = now_time - start_time
420 | start_time = now_time
421 | log_str_1 = 'Step %7d: %5.1f sec, epoch: %7.2f, ' % (self.step+1, batch_duration, epochs)
422 | log_str_2 = 'losses: %7.3g, %7.3g, %7.3g, %7.3g, %7.3g, %7.3g;' % \
423 | (step_losses_G[0], step_losses_G[1], step_losses_G[2], step_losses_D[0], step_losses_D[1], step_losses_D[2])
424 | print(log_str_1, end='')
425 | print(log_str_2)
426 | log_file_name = os.path.join(self.config.train_dir,'log.txt')
427 | with open(log_file_name, 'a') as log_file:
428 | log_file.write(log_str_1+log_str_2+'\n')
429 |
430 | if epochs >= self.config.max_epochs:
431 | break
432 |
433 | self.step += 1
434 |
435 | coord.request_stop()
436 | coord.join(threads)
437 |
438 | def test(self, sess, views, num_shapes):
439 |
440 | print('Testing...')
441 |
442 | self.saver = tf.train.Saver()
443 | ckpt = tf.train.get_checkpoint_state(self.config.train_dir)
444 | if ckpt and ckpt.model_checkpoint_path:
445 | self.saver.restore(sess, ckpt.model_checkpoint_path)
446 | try:
447 | self.step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
448 | except ValueError:
449 | self.step = 0
450 | else:
451 | print('Cannot find any checkpoint file')
452 | return
453 |
454 | coord = tf.train.Coordinator()
455 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
456 | self.summarizer = tf.summary.FileWriter(self.config.test_dir, sess.graph)
457 |
458 | output_count = 0
459 | output_prefix = 'dn14'
460 | output_images_folder = 'images'
461 | output_results_folder = 'results'
462 |
463 | log_file_name = os.path.join(self.config.test_dir,'log.txt')
464 | log_file = open(log_file_name, 'a')
465 |
466 | started = False
467 | finished = False
468 | last_shape_name = ''
469 | last_view_name = ''
470 | while not finished:
471 | names,results,errors,images = sess.run([self.names, self.results, self.errors, self.pngs])
472 | for k in range(len(names)):
473 | shape_name, view_name = names[k].decode('utf8').split('--')
474 | if last_shape_name == shape_name:
475 | view_name = ('%s' % (int(last_view_name)+1))
476 | last_shape_name = shape_name
477 | last_view_name = view_name
478 | print('Processed %d: %s--%s %f' % (output_count, shape_name, view_name, errors[k]))
479 |
480 | if view_name == '0' and started:
481 | log_file.write('\n')
482 | started = True
483 | log_file.write('%6f ' % errors[k])
484 |
485 | # export images
486 | name_input = os.path.join(self.config.test_dir, output_images_folder, shape_name, 'input.png')
487 | image.write_image(name_input, images[0, k])
488 | name_gt = os.path.join(self.config.test_dir, output_images_folder, shape_name, ('gt-'+output_prefix+'--'+view_name+'.png'))
489 | name_output = os.path.join(self.config.test_dir, output_images_folder, shape_name, ('pred-'+output_prefix+'--'+view_name+'.png'))
490 | image.write_image(name_gt, images[1, k])
491 | image.write_image(name_output, images[2, k])
492 |
493 | name_normal = os.path.join(self.config.test_dir, output_images_folder, shape_name, ('normal-'+output_prefix+'--'+view_name+'.png'))
494 | name_depth = os.path.join(self.config.test_dir, output_images_folder, shape_name, ('depth-'+output_prefix+'--'+view_name+'.png'))
495 | name_mask = os.path.join(self.config.test_dir, output_images_folder, shape_name, ('mask-'+output_prefix+'--'+view_name+'.png'))
496 | image.write_image(name_normal, images[3, k])
497 | image.write_image(name_depth, images[4, k])
498 | image.write_image(name_mask, images[5, k])
499 |
500 | # export results
501 | name_output = os.path.join(self.config.test_dir, output_results_folder, shape_name, (output_prefix+'-'+view_name+'.png'))
502 | image.write_image(name_output, images[2, k])
503 |
504 | # check termination
505 | output_count += 1
506 | if output_count >= num_shapes * views.num_views:
507 | finished = True
508 | break
509 |
510 | coord.request_stop()
511 | coord.join(threads)
512 |
513 | def encode(self, sess, views, num_shapes):
514 |
515 | print('Encoding...')
516 |
517 | self.saver = tf.train.Saver()
518 | ckpt = tf.train.get_checkpoint_state(self.config.train_dir)
519 | if ckpt and ckpt.model_checkpoint_path:
520 | self.saver.restore(sess, ckpt.model_checkpoint_path)
521 | self.step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
522 | else:
523 | print('Cannot find any checkpoint file')
524 | return
525 |
526 | coord = tf.train.Coordinator()
527 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
528 | self.summarizer = tf.summary.FileWriter(self.config.encode_dir, sess.graph)
529 |
530 | output_count = 0
531 | output_folder = 'features'
532 |
533 | finished = False
534 | while not finished:
535 | names,features = sess.run([self.encode_names, self.encode_features])
536 | for k in range(len(names)):
537 | shape_name = names[k].decode('utf8')
538 | print('Processed %d: %s' % (output_count, shape_name))
539 |
540 | # export results
541 | name_output = os.path.join(self.config.encode_dir, output_folder, (shape_name+'.bin'))
542 | data.write_bin_data(name_output, features[k])
543 |
544 | # check termination
545 | output_count += 1
546 | if output_count >= num_shapes:
547 | finished = True
548 | break
549 |
550 | coord.request_stop()
551 | coord.join(threads)
552 |
553 | def validate_loss(self, sess, num_shapes):
554 |
555 | num_processed_shapes = 0
556 | cum_losses = None
557 | while num_processed_shapes < num_shapes:
558 | losses = sess.run(self.valid_losses)
559 | losses = np.array(losses)
560 | cum_losses = losses if cum_losses is None else cum_losses+losses
561 | num_processed_shapes += self.config.batch_size
562 | cum_losses /= num_processed_shapes
563 |
564 | print('===== validation loss: %.3g' % cum_losses[0])
565 |
566 | summary_str = sess.run(self.valid_summary_op, feed_dict={self.valid_summary_losses:cum_losses})
567 | self.summarizer.add_summary(summary_str, self.step)
568 |
569 | def validate_output(self, sess, num_shapes, epochs):
570 |
571 | print('===== validation output')
572 | valid_results_folder = 'epoch-%.2f' % epochs
573 | names, images = sess.run([self.names, self.valid_images])
574 |
575 | for k in range(len(names)):
576 | shape_name, view_name = names[k].decode('utf8').split('--')
577 | if view_name == '0':
578 | print(shape_name)
579 |
580 | name_output = os.path.join(self.config.train_dir, valid_results_folder, shape_name, ('output--'+view_name+'.png'))
581 | name_gt = os.path.join(self.config.train_dir, valid_results_folder, shape_name, ('gt--'+view_name+'.png'))
582 | image.write_image(name_output, images[0, k])
583 | image.write_image(name_gt, images[1, k])
584 |
585 | name_normal = os.path.join(self.config.train_dir, valid_results_folder, shape_name, ('normal--'+view_name+'.png'))
586 | name_depth = os.path.join(self.config.train_dir, valid_results_folder, shape_name, ('depth--'+view_name+'.png'))
587 | name_mask = os.path.join(self.config.train_dir, valid_results_folder, shape_name, ('mask--'+view_name+'.png'))
588 | image.write_image(name_normal, images[2, k])
589 | image.write_image(name_depth, images[3, k])
590 | image.write_image(name_mask, images[4, k])
591 |
592 | # loop over all remaining shapes in the queue...
593 | num_processed_shapes = self.config.batch_size
594 | while num_processed_shapes < num_shapes:
595 | sess.run(self.names)
596 | num_processed_shapes += self.config.batch_size
597 |
598 | def cumulate_gradients(self, cum_grads, grads):
599 | if cum_grads is None:
600 | cum_grads = grads
601 | else:
602 | for k in range(len(grads)):
603 | cum_grads[k] += grads[k]
604 | return cum_grads
--------------------------------------------------------------------------------