├── .gitignore
├── .idea
│   ├── TensorFlow_DCIGN.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── Bunch.py
├── DCIGNModel.py
├── IGNModel.py
├── LICENSE
├── Model.py
├── README.md
├── __init__.py
├── activation_functions.py
├── autoencoder.py
├── data_preprocessing
│   ├── generate.command
│   ├── movie.command
│   ├── resize.command
│   ├── resize.sh
│   ├── resize_grey.command
│   ├── resize_no_gun.command
│   ├── tar_all.command
│   └── to32_32.command
├── experiments.py
├── input.py
├── metrics.py
├── model_interpreter.py
├── model_interpreter_test.py
├── network_utils.py
├── tools
│   ├── __init__.py
│   ├── checkpoint_utils.py
│   ├── freeze_graph.py
│   ├── freeze_graph_test.py
│   ├── graph_metrics.py
│   ├── graph_metrics_test.py
│   ├── inspect_checkpoint.py
│   ├── strip_unused.py
│   └── strip_unused_test.py
├── utils.py
├── video_builder.py
├── visualization.py
└── visualize_latest.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 |
 6 | # C extensions
 7 | *.so
 8 |
 9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 |
92 | # Custom
93 | .idea/
94 | tmp/
95 |

--------------------------------------------------------------------------------
/.idea/TensorFlow_DCIGN.iml:
--------------------------------------------------------------------------------
(IDE module file; the XML tag content was stripped during extraction and nothing recoverable remains)

--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(IDE inspection settings; XML tag content stripped during extraction)

--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(IDE settings file; XML tag content stripped during extraction)

--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(IDE module registry; XML tag content stripped during extraction)

--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
(IDE workspace file; XML tag content was stripped during extraction. Only a change-list timestamp, 1484656314255, survives.)

--------------------------------------------------------------------------------
/Bunch.py:
--------------------------------------------------------------------------------
 1 | import os
 2 |
 3 |
 4 | class Bunch(object):
 5 |   def __init__(self, **kwds):
 6 |     self.__dict__.update(kwds)
 7 |
 8 |   def __eq__(self, other):
 9 |     return self.__dict__ == other.__dict__
10 |
11 |   def __str__(self):
12 |     string = "BUNCH {" + str(self.__dict__)[2:]
13 |     string = string.replace("': ", ":")
14 |     string = string.replace(", '", ", ")
15 |     return string
16 |
17 |   def __repr__(self):
18 |     return str(self)
19 |
20 |   def to_file_name(self, folder=None, ext=None):
21 |     res = str(self.__dict__)[2:-1]
22 |     res = res.replace("'", "")
23 |     res = res.replace(": ", ".")
24 |     parts = res.split(', ')
25 |     res = '_'.join(sorted(parts))
26 |
27 |     if ext is not None:
28 |       res = '%s.%s' % (res, ext)
29 |     if folder is not None:
30 |       res = os.path.join(folder, res)
31 |     return res
32 |
33 |
34 | if __name__ == '__main__':
35 |   b = Bunch(x=5, y='something', other=9.0)
36 |   print(b)
37 |   print(b.to_file_name())
38 |   print(b.to_file_name('./here', 'txt'))
39 |

--------------------------------------------------------------------------------
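Bunch is the lightweight config/metadata container used throughout this repo; to_file_name() gives every hyperparameter combination a deterministic, alphabetically sorted file name, so the same configuration always maps to the same checkpoint path. A quick sketch with made-up values (note that list-valued entries such as layer sizes would be split on their embedded commas):

from Bunch import Bunch

meta = Bunch(lr=0.0004, bs=20, suf='run')
print(meta.to_file_name('./here', 'txt'))  # ./here/bs.20_lr.0.0004_suf.run.txt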
/DCIGNModel.py:
--------------------------------------------------------------------------------
 1 | """Deep Convolutional Inverse Graphics Network (DCIGN) model."""
 2 |
 3 | from __future__ import absolute_import
 4 | from __future__ import division
 5 | from __future__ import print_function
 6 |
 7 | from six.moves import xrange  # pylint: disable=redefined-builtin
 8 | import tensorflow as tf
 9 | import json, os, re, math, sys
10 | import numpy as np
11 | import utils as ut
12 | import input as inp
13 | import tools.checkpoint_utils as ch_utils
14 | import activation_functions as act
15 | import visualization as vis
16 | import prettytensor as pt
17 | import prettytensor.bookkeeper as bookkeeper
18 | import deconv
19 | from tensorflow.python.ops import gradients
20 | from prettytensor.tutorial import data_utils
21 | import IGNModel
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 | DEV = False
26 |
27 |
28 | class DCIGNModel(IGNModel.IGNModel):
29 |   model_id = 'dcign'
30 |
31 |   def _build_encoder(self):
32 |     """Construct encoder network: placeholders, operations, optimizer"""
33 |     self._input = tf.placeholder(tf.float32, self._batch_shape, name='input')
34 |     self._encoding = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow), name='encoding')
35 |
36 |     # Convolutional encoder; replaces the fully-connected encoder of IGNModel
37 |     self._encode = pt.wrap(self._input)
38 |     self._encode = self._encode.conv2d(5, 32, stride=2)
39 |     print(self._encode.get_shape())
40 |     self._encode = self._encode.conv2d(5, 64, stride=2)
41 |     print(self._encode.get_shape())
42 |     self._encode = self._encode.conv2d(5, 128, stride=2)
43 |     print(self._encode.get_shape())
44 |     self._encode = (self._encode.dropout(0.9).
45 |                     flatten().
46 |                     fully_connected(self.layer_narrow, activation_fn=None))
47 |
48 |     self._encoder_loss = self._encode.l1_regression(pt.wrap(self._encoding))
49 |     ut.print_info('new learning rate: %.8f (%f)' % (FLAGS.learning_rate/FLAGS.batch_size, FLAGS.learning_rate))
50 |     self._opt_encoder = self._optimizer(learning_rate=FLAGS.learning_rate/FLAGS.batch_size)
51 |     self._train_encoder = self._opt_encoder.minimize(self._encoder_loss)
52 |
53 |   def _build_decoder(self, weight_init=tf.truncated_normal):
54 |     """Construct decoder network: placeholders, operations, optimizer,
55 |     extract gradient back-prop for encoding layer"""
56 |     self._clamped = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow))
57 |     self._reconstruction = tf.placeholder(tf.float32, self._batch_shape)
58 |
59 |     clamped_init = np.zeros((FLAGS.batch_size, self.layer_narrow), dtype=np.float32)
60 |     self._clamped_variable = tf.Variable(clamped_init, name='clamped')
61 |     self._assign_clamped = tf.assign(self._clamped_variable, self._clamped)
62 |
63 |     self._decode = pt.wrap(self._clamped_variable)
64 |     print(self._decode.get_shape())
65 |     self._decode = self._decode.fully_connected(7200)
66 |     self._decode = self._decode.reshape([FLAGS.batch_size, 1, 1, 7200])
67 |     self._decode = self._decode.deconv2d((10, 20), 128, edges='VALID')
68 |     print(self._decode.get_shape())
69 |     self._decode = self._decode.deconv2d(5, 64, stride=2)
70 |     print(self._decode.get_shape())
71 |     self._decode = self._decode.deconv2d(5, 32, stride=2)
72 |     print(self._decode.get_shape())
73 |     self._decode = self._decode.deconv2d(5, 3, stride=2, activation_fn=tf.nn.sigmoid)
74 |     print(self._decode.get_shape())
75 |
76 |     self._decoder_loss = self._decode.l2_regression(pt.wrap(self._reconstruction))
77 |     self._opt_decoder = self._optimizer(learning_rate=FLAGS.learning_rate/FLAGS.batch_size)
78 |     self._train_decoder = self._opt_decoder.minimize(self._decoder_loss)
79 |
80 |     self._clamped_grad, = tf.gradients(self._decoder_loss, [self._clamped_variable])
81 |
82 |
83 | def parse_params():
84 |   params = {}
85 |   for i, param in enumerate(sys.argv):
86 |     if '-' in param:
87 |       params[param[1:]] = sys.argv[i+1]
88 |   print(params)
89 |   return params
90 |
91 |
92 | if __name__ == '__main__':
93 |   epochs = 500
94 |
95 |   FLAGS.save_every = 5
96 |   FLAGS.save_encodings_every = 2
97 |
98 |   model = DCIGNModel()
99 |   args = dict([arg.split('=', maxsplit=1) for arg in sys.argv[1:]])
100 |   if len(args) == 0:
101 |     DEV = True  # development mode when started without arguments
102 |     print('DEVELOPMENT MODE ON')
103 |   print(args)
104 |   if 'epochs' in args:
105 |     epochs = int(args['epochs'])
106 |     ut.print_info('epochs: %d' % epochs, color=36)
107 |   if 'sigma' in args:
108 |     FLAGS.sigma = int(args['sigma'])
109 |   if 'suffix' in args:
110 |     FLAGS.suffix = args['suffix']
111 |   if 'input' in args:
112 |     parts = FLAGS.input_path.split('/')
113 |     parts[-3] = args['input']
114 |     FLAGS.input_path = '/'.join(parts)
115 |     ut.print_info('input %s' % FLAGS.input_path, color=36)
116 |   if 'h' in args:
117 |     layers = list(map(int, args['h'].split('/')))
118 |     ut.print_info('layers %s' % str(layers), color=36)
119 |     model.set_layer_sizes(layers)
120 |   if 'divider' in args:
121 |     FLAGS.gradient_proportion = float(args['divider'])
122 |   if 'lr' in args:
123 |     FLAGS.learning_rate = float(args['lr'])
124 |
125 |   model.train(epochs)
126 |

--------------------------------------------------------------------------------
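The stride-2 SAME convolutions above halve each spatial dimension, and the decoder mirrors them with stride-2 deconvolutions; together with the fully_connected(7200) -> reshape(1, 1, 7200) -> deconv2d((10, 20), 128, 'VALID') bottleneck this implies an input of roughly 80x160x3. A standalone check of that shape arithmetic (the input resolution is inferred from the decoder shapes, so treat it as an assumption):

import math

def conv_same(size, stride=2):
  # spatial output size of a SAME-padded strided convolution
  return math.ceil(size / stride)

h, w = 80, 160  # assumed input resolution
for depth in (32, 64, 128):
  h, w = conv_same(h), conv_same(w)
  print('%3dx%3dx%d' % (h, w, depth))
# 40x80x32 -> 20x40x64 -> 10x20x128; the decoder runs the chain backwards:
# 1x1x7200 -> 10x20x128 -> 20x40x64 -> 40x80x32 -> 80x160x3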
/IGNModel.py:
--------------------------------------------------------------------------------
 1 | """Inverse Graphics Network (IGN) model."""
 2 |
 3 | from __future__ import absolute_import
 4 | from __future__ import division
 5 | from __future__ import print_function
 6 |
 7 | from six.moves import xrange  # pylint: disable=redefined-builtin
 8 | import tensorflow as tf
 9 | import json, os, re, math, sys, time
10 | import numpy as np
11 | import utils as ut
12 | import input as inp
13 | import tools.checkpoint_utils as ch_utils
14 | import activation_functions as act
15 | import visualization as vis
16 | import prettytensor as pt
17 | import Model as m
18 |
19 | tf.app.flags.DEFINE_float('gradient_proportion', 5.0, 'Proportion of gradient mixture RECO/DRAG')
20 | tf.app.flags.DEFINE_integer('sequence_length', 25, 'size of the 1-variable-variation sequences')
21 |
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | DEV = False
25 |
26 |
27 | def _clamp(encoding, filter):
28 |   filter_neg = np.ones(len(filter), dtype=filter.dtype) - filter
29 |   avg = encoding.mean(axis=0)*filter_neg
30 |   grad = encoding*filter_neg - avg
31 |   encoding = encoding * filter + avg
32 |   return encoding, grad
33 |
34 |
35 | def _declamp_grad(vae_grad, reco_grad, filter):
36 |   res = vae_grad/FLAGS.gradient_proportion + reco_grad*filter
37 |   return res
38 |
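Each training batch is a sequence in which exactly one latent unit varies; filter is a one-hot mask over latent units marking that unit. _clamp replaces every other unit with its batch mean and keeps the residual as a gradient that _declamp_grad later mixes back in, dragging the inactive units toward the mean. A toy run with a 2-unit code (shapes and numbers invented for illustration):

import numpy as np

encoding = np.array([[0.2, 1.0],
                     [0.4, 3.0],
                     [0.6, 5.0]], dtype=np.float32)  # 3-frame sequence, unit 1 varies
filter = np.array([0., 1.], dtype=np.float32)        # one-hot mask of the varied unit

filter_neg = 1. - filter
avg = encoding.mean(axis=0) * filter_neg  # [0.4, 0.]: batch mean of the clamped unit
grad = encoding * filter_neg - avg        # [[-0.2, 0.], [0., 0.], [0.2, 0.]]
clamped = encoding * filter + avg         # [[0.4, 1.], [0.4, 3.], [0.4, 5.]]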
 39 |
 40 | class IGNModel(m.Model):
 41 |   model_id = 'ign'
 42 |   decoder_scope = 'dec'
 43 |   encoder_scope = 'enc'
 44 |
 45 |   layer_narrow = 2
 46 |   layer_encoder = 40
 47 |   layer_decoder = 40
 48 |
 49 |   _image_shape = None
 50 |   _batch_shape = None
 51 |
 52 |   # placeholders
 53 |   _input = None
 54 |   _encoding = None
 55 |
 56 |   _clamped = None
 57 |   _reconstruction = None
 58 |   _clamped_grad = None
 59 |
 60 |   # variables
 61 |   _clamped_variable = None
 62 |
 63 |   # operations
 64 |   _encode = None
 65 |   _encoder_loss = None
 66 |   _opt_encoder = None
 67 |   _train_encoder = None
 68 |
 69 |   _decode = None
 70 |   _decoder_loss = None
 71 |   _opt_decoder = None
 72 |   _train_decoder = None
 73 |
 74 |   _step = None
 75 |   _current_step = None
 76 |   _visualize_op = None
 77 |
 78 |   def __init__(self,
 79 |                weight_init=None,
 80 |                activation=act.sigmoid,
 81 |                optimizer=tf.train.AdamOptimizer):
 82 |     super(IGNModel, self).__init__()
 83 |     FLAGS.batch_size = FLAGS.sequence_length
 84 |     self._weight_init = weight_init
 85 |     self._activation = activation
 86 |     self._optimizer = optimizer
 87 |     if FLAGS.load_from_checkpoint:
 88 |       self.load_meta(FLAGS.load_from_checkpoint)
 89 |
 90 |   def get_layer_info(self):
 91 |     return [self.layer_encoder, self.layer_narrow, self.layer_decoder]
 92 |
 93 |   def get_meta(self, meta=None):
 94 |     meta = super(IGNModel, self).get_meta(meta=meta)
 95 |     meta['div'] = FLAGS.gradient_proportion
 96 |     return meta
 97 |
 98 |   def load_meta(self, save_path):
 99 |     meta = super(IGNModel, self).load_meta(save_path)
100 |     self._weight_init = meta['init']
101 |     self._optimizer = tf.train.AdamOptimizer \
102 |         if 'Adam' in meta['opt'] \
103 |         else tf.train.AdadeltaOptimizer
104 |     self._activation = act.sigmoid
105 |     self.layer_encoder = meta['h'][0]
106 |     self.layer_narrow = meta['h'][1]
107 |     self.layer_decoder = meta['h'][2]
108 |     FLAGS.gradient_proportion = float(meta['div'])
109 |     ut.configure_folders(FLAGS, self.get_meta())
110 |     return meta
111 |
112 |   # MODEL
113 |
114 |   def build_model(self):
115 |     tf.reset_default_graph()
116 |     self._batch_shape = inp.get_batch_shape(FLAGS.batch_size, FLAGS.input_path)
117 |     self._current_step = tf.Variable(0, trainable=False, name='global_step')
118 |     self._step = tf.assign(self._current_step, self._current_step + 1)
119 |     with pt.defaults_scope(activation_fn=self._activation.func):
120 |       with pt.defaults_scope(phase=pt.Phase.train):
121 |         with tf.variable_scope(self.encoder_scope):
122 |           self._build_encoder()
123 |         with tf.variable_scope(self.decoder_scope):
124 |           self._build_decoder()
125 |
126 |   def _build_encoder(self):
127 |     """Construct encoder network: placeholders, operations, optimizer"""
128 |     self._input = tf.placeholder(tf.float32, self._batch_shape, name='input')
129 |     self._encoding = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow), name='encoding')
130 |
131 |     self._encode = (pt.wrap(self._input)
132 |                     .flatten()
133 |                     .fully_connected(self.layer_encoder, name='enc_hidden')
134 |                     .fully_connected(self.layer_narrow, name='narrow'))
135 |
136 |     self._encoder_loss = self._encode.l1_regression(pt.wrap(self._encoding))
137 |     ut.print_info('new learning rate: %.8f (%f)' % (FLAGS.learning_rate/FLAGS.batch_size, FLAGS.learning_rate))
138 |     self._opt_encoder = self._optimizer(learning_rate=FLAGS.learning_rate/FLAGS.batch_size)
139 |     self._train_encoder = self._opt_encoder.minimize(self._encoder_loss)
140 |
141 |   def _build_decoder(self, weight_init=tf.truncated_normal):
142 |     """Construct decoder network: placeholders, operations, optimizer,
143 |     extract gradient back-prop for encoding layer"""
144 |     self._clamped = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow))
145 |     self._reconstruction = tf.placeholder(tf.float32, self._batch_shape)
146 |
147 |     clamped_init = np.zeros((FLAGS.batch_size, self.layer_narrow), dtype=np.float32)
148 |     self._clamped_variable = tf.Variable(clamped_init, name='clamped')
149 |     self._assign_clamped = tf.assign(self._clamped_variable, self._clamped)
150 |
151 |     # http://stackoverflow.com/questions/40194389/how-to-propagate-gradient-into-a-variable-after-assign-operation
152 |     self._decode = (
153 |       pt.wrap(self._clamped_variable)
154 |         .fully_connected(self.layer_decoder, name='decoder_1')
155 |         .fully_connected(np.prod(self._image_shape), init=weight_init, name='output')
156 |         .reshape(self._batch_shape))
157 |
158 |     self._decoder_loss = self._build_reco_loss(self._reconstruction)
159 |     self._opt_decoder = self._optimizer(learning_rate=FLAGS.learning_rate)
160 |     self._train_decoder = self._opt_decoder.minimize(self._decoder_loss)
161 |
162 |     self._clamped_grad, = tf.gradients(self._decoder_loss, [self._clamped_variable])
163 |
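The decoder deliberately reads from a Variable that is set via tf.assign; tf.gradients then recovers the decoder loss gradient at the encoding itself (the pattern discussed in the Stack Overflow link above). Stripped to its essentials, it looks like this standalone TF 1.x sketch, with a one-layer dense network standing in for the real decoder stack:

import numpy as np
import tensorflow as tf

code = tf.Variable(np.zeros((4, 2), dtype=np.float32), name='clamped')
code_in = tf.placeholder(tf.float32, (4, 2))
assign_code = tf.assign(code, code_in)

target = tf.placeholder(tf.float32, (4, 8))
decoded = tf.layers.dense(code, 8)       # stand-in for the real decoder
loss = tf.nn.l2_loss(decoded - target)
code_grad, = tf.gradients(loss, [code])  # dLoss/dCode at the assigned encoding

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(assign_code, {code_in: np.random.randn(4, 2).astype(np.float32)})
  g = sess.run(code_grad, {target: np.zeros((4, 8), np.float32)})
  print(g.shape)  # (4, 2)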
164 |
165 |   # DATA
166 |
167 |   def fetch_datasets(self, activation_func_bounds):
168 |     original_data, filters = inp.get_images(FLAGS.input_path)
169 |     assert len(filters) == len(original_data)
170 |     original_data, filters = self.bloody_hack_filterbatches(original_data, filters)
171 |     ut.print_info('shapes. data, filters: %s' % str((original_data.shape, filters.shape)))
172 |
173 |     original_data = inp.rescale_ds(original_data, activation_func_bounds.min, activation_func_bounds.max)
174 |     self._image_shape = inp.get_image_shape(FLAGS.input_path)
175 |
176 |     if DEV:
177 |       original_data = original_data[:300]
178 |
179 |     self.epoch_size = math.ceil(len(original_data) / FLAGS.batch_size)
180 |     self.test_size = math.ceil(len(original_data) / FLAGS.batch_size)
181 |     return original_data, filters
182 |
183 |   def bloody_hack_filterbatches(self, original_data, filters):
184 |     survivers = np.zeros(len(filters), dtype=np.uint8)
185 |     j, prev = 0, None
186 |     for _, f in enumerate(filters):
187 |       if prev is None or prev[0] == f[0] and prev[1] == f[1]:
188 |         j += 1
189 |       else:
190 |         k = j // FLAGS.batch_size
191 |         for i in range(k):
192 |           start = _ - j + math.ceil(j / k * i)
193 |           survivers[start:start + FLAGS.batch_size] += 1
194 |         j = 0
195 |       prev = f
196 |     original_data = np.asarray([x for i, x in enumerate(original_data) if survivers[i] > 0])
197 |     filters = np.asarray([x for i, x in enumerate(filters) if survivers[i] > 0])
198 |     return original_data, filters
199 |
200 |   def _get_epoch_dataset(self):
201 |     ds, filters = self._get_blurred_dataset(), self._filters
202 |     # permute
203 |     (train_set, filters), permutation = inp.permute_data_in_series((ds, filters), FLAGS.batch_size, allow_shift=False)
204 |     # construct feed
205 |     feed = pt.train.feed_numpy(FLAGS.batch_size, train_set, filters)
206 |     return feed, permutation
207 |
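bloody_hack_filterbatches keeps only those stretches of the dataset that can fill whole batches under one constant filter, so every batch that _get_epoch_dataset later emits varies exactly one latent unit (cf. the assert in the training loop below). The invariant in miniature, with hypothetical run lengths:

run_lengths = [12, 5, 3]  # consecutive frames sharing one filter, i.e. one varied unit
batch_size = 5            # FLAGS.batch_size == FLAGS.sequence_length
kept = [(r // batch_size) * batch_size for r in run_lengths]
print(kept)  # [10, 5, 0]: runs shorter than a full batch are dropped entirely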
208 |   # TRAIN
209 |
210 |   def train(self, epochs_to_train=5):
211 |     meta = self.get_meta()
212 |     ut.print_time('train started: \n%s' % ut.to_file_name(meta))
213 |     ut.configure_folders(FLAGS, meta)
214 |
215 |     self._dataset, self._filters = self.fetch_datasets(self._activation)
216 |     self.build_model()
217 |
218 |     with tf.Session() as sess:
219 |       sess.run(tf.global_variables_initializer())
220 |       self._saver = tf.train.Saver()
221 |       self._register_training_start(sess)
222 |
223 |       if FLAGS.load_state and os.path.exists(self.get_checkpoint_path()):
224 |         self._saver.restore(sess, self.get_checkpoint_path())
225 |         ut.print_info('Restored requested. Previous epoch: %d' % self.get_past_epochs(), color=31)
226 |
227 |       # MAIN LOOP
228 |       for current_epoch in xrange(epochs_to_train):
229 |         start = time.time()
230 |         feed, permutation = self._get_epoch_dataset()
231 |         for _, batch in enumerate(feed):
232 |           filter = batch[1][0]
233 |           assert batch[1][0, 0] == batch[1][-1, 0]
234 |           encoding, = sess.run([self._encode], feed_dict={self._input: batch[0]})   # 1.1 encode: forward pass
235 |           clamped_enc, vae_grad = _clamp(encoding, filter)                          # 1.2 clamp
236 |
237 |           sess.run(self._assign_clamped, feed_dict={self._clamped: clamped_enc})
238 |           reconstruction, loss, clamped_gradient, _ = sess.run(                     # 2.1 decode: forward + backward
239 |             [self._decode, self._decoder_loss, self._clamped_grad, self._train_decoder],
240 |             feed_dict={self._clamped: clamped_enc, self._reconstruction: batch[0]})
241 |
242 |           declamped_grad = _declamp_grad(vae_grad, clamped_gradient, filter)        # 2.2 prepare gradient
243 |           _, step = sess.run(                                                       # 3.0 encode: backward pass
244 |             [self._train_encoder, self._step],
245 |             feed_dict={self._input: batch[0], self._encoding: encoding - declamped_grad})
246 |
247 |           self._register_batch(loss, batch, encoding, reconstruction)
248 |         self._register_epoch(current_epoch, epochs_to_train, time.time() - start, sess)
249 |       self._writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)
250 |       meta = self._register_training()
251 |     return meta, self._stats['epoch_accuracy']
252 |
253 |
254 | if __name__ == '__main__':
255 |   # FLAGS.load_from_checkpoint = './tmp/doom_bs__act|sigmoid__bs|20__h|500|5|500__init|na__inp|cbd4__lr|0.0004__opt|AO'
256 |   epochs = 300
257 |
258 |   model = IGNModel()
259 |   args = dict([arg.split('=', maxsplit=1) for arg in sys.argv[1:]])
260 |   if len(args) == 0:
261 |     DEV = True  # development mode when started without arguments
262 |     print('DEVELOPMENT MODE ON')
263 |   print(args)
264 |   if 'epochs' in args:
265 |     epochs = int(args['epochs'])
266 |     ut.print_info('epochs: %d' % epochs, color=36)
267 |   if 'sigma' in args:
268 |     FLAGS.sigma = int(args['sigma'])
269 |   if 'suffix' in args:
270 |     FLAGS.suffix = args['suffix']
271 |   if 'input' in args:
272 |     parts = FLAGS.input_path.split('/')
273 |     parts[-3] = args['input']
274 |     FLAGS.input_path = '/'.join(parts)
275 |     ut.print_info('input %s' % FLAGS.input_path, color=36)
276 |   if 'h' in args:
277 |     layers = list(map(int, args['h'].split('/')))
278 |     ut.print_info('layers %s' % str(layers), color=36)
279 |     model.set_layer_sizes(layers)
280 |   if 'divider' in args:
281 |     FLAGS.gradient_proportion = float(args['divider'])
282 |   if 'lr' in args:
283 |     FLAGS.learning_rate = float(args['lr'])
284 |
285 |   model.train(epochs)
286 |

--------------------------------------------------------------------------------
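Step 3.0 is the trick that holds the scheme together: encoder and decoder are never connected in one graph. The encoder is trained by L1-regressing its output toward `encoding - declamped_grad`, which pushes the code one step along the mixed gradient from steps 1.2 and 2.2. A small numeric restatement of that arithmetic (values invented; `proportion` plays the role of FLAGS.gradient_proportion):

import numpy as np

g_clamp = np.array([[-0.2, 0.0]])  # from _clamp: drags the inactive unit toward the batch mean
g_reco = np.array([[0.3, 0.08]])   # decoder loss gradient at the clamped code (step 2.1)
filter = np.array([0., 1.])        # the unit varied in this sequence
proportion = 5.0

g = g_clamp / proportion + g_reco * filter  # _declamp_grad, step 2.2
code = np.array([[0.4, 3.0]])
target = code - g                           # regression target handed to the encoder, step 3.0
print(g)       # [[-0.04  0.08]]
print(target)  # [[ 0.44  2.92]]: inactive unit moves to the mean, active unit follows reconstruction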
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 |
 3 | Copyright (c) 2016 yselivonchyk
 4 |
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
 1 | """Base model: shared flags, checkpointing, metadata and training bookkeeping."""
 2 |
 3 | from __future__ import absolute_import
 4 | from __future__ import division
 5 | from __future__ import print_function
 6 |
 7 | import tensorflow as tf
 8 | import json
 9 | import os
10 | import numpy as np
11 | import utils as ut
12 | import input as inp
13 | import tools.checkpoint_utils as ch_utils
14 | import visualization as vis
15 | import matplotlib.pyplot as plt
16 | import time
17 |
18 | tf.app.flags.DEFINE_string('suffix', 'run', 'Suffix to use to distinguish models by purpose')
19 | tf.app.flags.DEFINE_string('input_path', '../data/tmp/grid03.14.c.tar.gz', 'input folder')
20 | tf.app.flags.DEFINE_string('test_path', '../data/tmp/grid03.14.c.tar.gz', 'test set folder')
21 | tf.app.flags.DEFINE_float('test_max', 10000, 'max number of examples in the test set')
22 | tf.app.flags.DEFINE_string('save_path', './tmp/checkpoint', 'Where to save the model checkpoints.')
23 | tf.app.flags.DEFINE_string('logdir', '', 'where to save logs.')
24 | tf.app.flags.DEFINE_string('load_from_checkpoint', None, 'Load model state from particular checkpoint')
25 |
26 | tf.app.flags.DEFINE_integer('max_epochs', 50, 'Train for at most this number of epochs')
27 | tf.app.flags.DEFINE_integer('epoch_size', 100, 'Number of batches per epoch')
28 | tf.app.flags.DEFINE_integer('test_size', 0, 'Number of test batches per epoch')
29 | tf.app.flags.DEFINE_integer('save_every', 250, 'Save model state every INT epochs')
30 | tf.app.flags.DEFINE_integer('save_encodings_every', 5, 'Save encoding and visualizations every')
31 | tf.app.flags.DEFINE_boolean('load_state', True, 'Load state if possible ')
32 |
33 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
34 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
35 |
36 | tf.app.flags.DEFINE_float('dropout', 0.0, 'Dropout probability of pre-narrow units')
37 |
38 | tf.app.flags.DEFINE_float('blur', 5.0, 'Max sigma value for Gaussian blur applied to training set')
39 | tf.app.flags.DEFINE_integer('blur_decrease', 50000, 'Decrease image blur every X steps')
40 |
41 | tf.app.flags.DEFINE_boolean('dev', False, 'Indicate that model is in the development mode')
42 |
43 | FLAGS = tf.app.flags.FLAGS
44 | slim = tf.contrib.slim
45 |
46 | DEV = False
47 |
48 |
49 | def is_stopping_point(current_epoch, epochs_to_train, stop_every=None, stop_x_times=None,
50 |                       stop_on_last=True):
51 |   if stop_on_last and current_epoch + 1 == epochs_to_train:
52 |     return True
53 |   if stop_x_times is not None:
54 |     return current_epoch % np.ceil(epochs_to_train / float(stop_x_times)) == 0
55 |   if stop_every is not None:
56 |     return (current_epoch + 1) % stop_every == 0
57 |
58 |
59 | def get_variable(name):
60 |   assert FLAGS.load_from_checkpoint
61 |   var = ch_utils.load_variable(tf.train.latest_checkpoint(FLAGS.load_from_checkpoint), name)
62 |   return
var 63 | 64 | 65 | def get_every_dataset(): 66 | all_data = [x[0] for x in os.walk('../data/tmp_grey/') if 'img' in x[0]] 67 | print(all_data) 68 | return all_data 69 | 70 | 71 | class Model: 72 | model_id = 'base' 73 | dataset = None 74 | test_set = None 75 | 76 | _writer, _saver = None, None 77 | _dataset, _filters = None, None 78 | 79 | def get_layer_info(self): 80 | return [self.layer_encoder, self.layer_narrow, self.layer_decoder] 81 | 82 | # MODEL 83 | 84 | def build_model(self): 85 | pass 86 | 87 | def _build_encoder(self): 88 | pass 89 | 90 | def _build_decoder(self, weight_init=tf.truncated_normal): 91 | pass 92 | 93 | def _build_reco_loss(self, output_placeholder): 94 | error = self._decode - slim.flatten(output_placeholder) 95 | return tf.nn.l2_loss(error, name='reco_loss') 96 | 97 | def train(self, epochs_to_train=5): 98 | pass 99 | 100 | # META 101 | 102 | def get_meta(self, meta=None): 103 | meta = meta if meta else {} 104 | 105 | meta['postf'] = self.model_id 106 | meta['a'] = 's' 107 | meta['lr'] = FLAGS.learning_rate 108 | meta['init'] = self._weight_init 109 | meta['bs'] = FLAGS.batch_size 110 | meta['h'] = self.get_layer_info() 111 | meta['opt'] = self._optimizer 112 | meta['inp'] = inp.get_input_name(FLAGS.input_path) 113 | meta['do'] = FLAGS.dropout 114 | return meta 115 | 116 | def save_meta(self, meta=None): 117 | if meta is None: 118 | meta = self.get_meta() 119 | 120 | ut.configure_folders(FLAGS, meta) 121 | meta['a'] = 's' 122 | meta['opt'] = str(meta['opt']).split('.')[-1][:-2] 123 | meta['input_path'] = FLAGS.input_path 124 | path = os.path.join(FLAGS.save_path, 'meta.txt') 125 | json.dump(meta, open(path, 'w')) 126 | 127 | def load_meta(self, save_path): 128 | path = os.path.join(save_path, 'meta.txt') 129 | meta = json.load(open(path, 'r')) 130 | FLAGS.save_path = save_path 131 | FLAGS.batch_size = meta['bs'] 132 | FLAGS.input_path = meta['input_path'] 133 | FLAGS.learning_rate = meta['lr'] 134 | FLAGS.load_state = True 135 | FLAGS.dropout = float(meta['do']) 136 | return meta 137 | 138 | # DATA 139 | 140 | _blurred_dataset, _last_blur = None, 0 141 | 142 | def _get_blur_sigma(self, step=None): 143 | step = step if step is not None else self._current_step.eval() 144 | calculated_sigma = FLAGS.blur - int(10 * step / FLAGS.blur_decrease) / 10.0 145 | return max(0, calculated_sigma) 146 | 147 | def _get_blurred_dataset(self): 148 | if FLAGS.blur != 0: 149 | current_sigma = self._get_blur_sigma() 150 | if current_sigma != self._last_blur: 151 | self._last_blur = current_sigma 152 | self._blurred_dataset = inp.apply_gaussian(self.dataset, sigma=current_sigma) 153 | return self._blurred_dataset if self._blurred_dataset is not None else self.dataset 154 | 155 | # MISC 156 | 157 | def get_past_epochs(self): 158 | return int(self._current_step.eval() / FLAGS.epoch_size) 159 | 160 | @staticmethod 161 | def get_checkpoint_path(): 162 | return os.path.join(FLAGS.save_path, '-9999.chpt') 163 | 164 | # OUTPUTS 165 | @staticmethod 166 | def _get_stats_template(): 167 | return { 168 | 'batch': [], 169 | 'input': [], 170 | 'encoding': [], 171 | 'reconstruction': [], 172 | 'total_loss': 0, 173 | 'start': time.time() 174 | } 175 | 176 | _epoch_stats = None 177 | _stats = None 178 | 179 | @ut.timeit 180 | def restore_model(self, session): 181 | self._saver = tf.train.Saver() 182 | latest_checkpoint = tf.train.latest_checkpoint(self.get_checkpoint_path()[:-10]) 183 | ut.print_info("latest checkpoint: %s" % latest_checkpoint) 184 | if FLAGS.load_state and latest_checkpoint is not 
None: 185 | self._saver.restore(session, latest_checkpoint) 186 | ut.print_info('Restored requested. Previous epoch: %d' % self.get_past_epochs(), color=31) 187 | 188 | def _register_training_start(self, sess): 189 | self.summary_writer = tf.summary.FileWriter('/tmp/train', sess.graph) 190 | 191 | self._epoch_stats = self._get_stats_template() 192 | self._stats = { 193 | 'epoch_accuracy': [], 194 | 'epoch_reconstructions': [], 195 | 'permutation': None 196 | } 197 | 198 | if FLAGS.dev: 199 | plt.ion() 200 | plt.show() 201 | 202 | # @ut.timeit 203 | def _register_batch(self, loss, batch=None, encoding=None, reconstruction=None, step=None): 204 | self._epoch_stats['total_loss'] += loss 205 | if FLAGS.dev: 206 | assert batch is not None and reconstruction is not None 207 | original = batch[0][:, 0] 208 | vis.plot_reconstruction(original, reconstruction, interactive=True) 209 | 210 | MAX_IMAGES = 10 211 | 212 | # @ut.timeit 213 | def _register_epoch(self, epoch, total_epochs, elapsed, sess): 214 | if is_stopping_point(epoch, total_epochs, FLAGS.save_every): 215 | self._saver.save(sess, self.get_checkpoint_path()) 216 | 217 | accuracy = 100000 * np.sqrt(self._epoch_stats['total_loss'] / np.prod(self._batch_shape) / FLAGS.epoch_size) 218 | 219 | if is_stopping_point(epoch, total_epochs, FLAGS.save_encodings_every): 220 | digest = self.evaluate(sess, take=self.MAX_IMAGES) 221 | data = { 222 | 'enc': np.asarray(digest[0]), 223 | 'rec': np.asarray(digest[1]), 224 | 'blu': np.asarray(digest[2][:self.MAX_IMAGES]) 225 | } 226 | 227 | meta = {'suf': 'encodings', 'e': '%06d' % int(self.get_past_epochs()), 'er': int(accuracy)} 228 | projection_file = ut.to_file_name(meta, FLAGS.save_path) 229 | np.save(projection_file, data) 230 | vis.plot_encoding_crosssection(data['enc'], FLAGS.save_path, meta, data['blu'], data['rec']) 231 | 232 | self._stats['epoch_accuracy'].append(accuracy) 233 | self.print_epoch_info(accuracy, epoch, total_epochs, elapsed) 234 | if epoch + 1 != total_epochs: 235 | self._epoch_stats = self._get_stats_template() 236 | 237 | @ut.timeit 238 | def _register_training(self): 239 | best_acc = np.min(self._stats['epoch_accuracy']) 240 | meta = self.get_meta() 241 | meta['acu'] = int(best_acc) 242 | meta['e'] = self.get_past_epochs() 243 | ut.print_time('Best Quality: %f for %s' % (best_acc, ut.to_file_name(meta))) 244 | self.summary_writer.close() 245 | return meta 246 | 247 | def print_epoch_info(self, accuracy, current_epoch, epochs, elapsed): 248 | epochs_past = self.get_past_epochs() - current_epoch 249 | accuracy_info = '' if accuracy is None else '| accuracy %d' % int(accuracy) 250 | epoch_past_info = '' if epochs_past is None else '+%d' % (epochs_past - 1) 251 | epoch_count = 'Epochs %2d/%d%s' % (current_epoch + 1, epochs, epoch_past_info) 252 | time_info = '%2dms/bt' % (elapsed / FLAGS.epoch_size * 1000) 253 | 254 | info_string = ' '.join([ 255 | epoch_count, 256 | accuracy_info, 257 | time_info]) 258 | 259 | ut.print_time(info_string, same_line=True) 260 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCIGN_tensorflow 2 | Deep Convolutional Inverse Graphics network (DCIGN) implementation with Tensorflow 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yselivonchyk/TensorFlow_DCIGN/ff8d85f3a7b7ca1e5c3f50ff003a1c09a70067cd/__init__.py -------------------------------------------------------------------------------- /activation_functions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | sigmoid = type('Fake', (object,), { "func": tf.nn.sigmoid, "min": 0, 'max': 1}) 4 | tanh = type('Fake', (object,), { "func": tf.nn.tanh, "min": -1, 'max': 1}) 5 | relu = type('Fake', (object,), { "func": tf.nn.relu, "min": 0, 'max': 1}) 6 | -------------------------------------------------------------------------------- /autoencoder.py: -------------------------------------------------------------------------------- 1 | """MNIST Autoencoder. """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 9 | import numpy as np 10 | import tensorflow as tf 11 | import utils as ut 12 | import input as inp 13 | import visualization as vis 14 | import matplotlib.pyplot as plt 15 | import time 16 | import sys 17 | import getch 18 | import model_interpreter as interpreter 19 | import network_utils as nut 20 | import math 21 | from tensorflow.contrib.tensorboard.plugins import projector 22 | from Bunch import Bunch 23 | 24 | 25 | tf.app.flags.DEFINE_string('input_path', '../data/tmp/grid03.14.c.tar.gz', 'input folder') 26 | tf.app.flags.DEFINE_string('input_name', '', 'input folder') 27 | tf.app.flags.DEFINE_string('test_path', '', 'test set folder') 28 | tf.app.flags.DEFINE_string('net', 'f100-f3', 'model configuration') 29 | tf.app.flags.DEFINE_string('model', 'noise', 'Type of the model to use: Autoencoder (ae)' 30 | 'WhatWhereAe (ww) U-netAe (u)') 31 | tf.app.flags.DEFINE_string('postfix', '', 'Postfix for the training folder') 32 | 33 | tf.app.flags.DEFINE_float('alpha', 10, 'Predictive reconstruction loss weight') 34 | tf.app.flags.DEFINE_float('beta', 0.0005, 'Reconstruction from noisy data loss weight') 35 | tf.app.flags.DEFINE_float('epsilon', 0.000001, 36 | 'Diameter of epsilon sphere comparing to distance to a neighbour. 
<= 0.5')
 37 | tf.app.flags.DEFINE_float('gamma', 50., 'Loss weight for large distances')
 38 | tf.app.flags.DEFINE_float('distance', 0.01, 'Maximum allowed interpoint distance')
 39 | tf.app.flags.DEFINE_float('delta', 1., 'Loss weight for stacked objective')
 40 |
 41 | tf.app.flags.DEFINE_string('comment', '', 'Comment to leave by the model')
 42 |
 43 | tf.app.flags.DEFINE_float('test_max', 10000, 'max number of examples in the test set')
 44 |
 45 | tf.app.flags.DEFINE_integer('max_epochs', 0, 'Train for at most this number of epochs')
 46 | tf.app.flags.DEFINE_integer('save_every', 250, 'Save model state every INT epochs')
 47 | tf.app.flags.DEFINE_integer('eval_every', 25, 'Save encoding and visualizations every')
 48 | tf.app.flags.DEFINE_integer('visualiza_max', 10, 'Max pairs to show on visualization')
 49 | tf.app.flags.DEFINE_boolean('load_state', True, 'Load state if possible ')
 50 | tf.app.flags.DEFINE_boolean('kill_depth', False, 'Ignore depth information')
 51 | tf.app.flags.DEFINE_boolean('dev', False, 'Indicate development mode')
 52 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
 53 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
 54 |
 55 | tf.app.flags.DEFINE_float('blur', 5.0, 'Max sigma value for Gaussian blur applied to training set')
 56 | tf.app.flags.DEFINE_boolean('new_blur', False, 'Use data augmentation as blur info')
 57 | tf.app.flags.DEFINE_integer('blur_decrease', 10000, 'Decrease image blur every X steps')
 58 |
 59 | FLAGS = tf.app.flags.FLAGS
 60 | slim = tf.contrib.slim
 61 |
 62 |
 63 | AUTOENCODER = 'ae'
 64 | PREDICTIVE = 'pred'
 65 | DENOISING = 'noise'
 66 |
 67 | CHECKPOINT_NAME = '-9999.chpt'
 68 | EMB_SUFFIX = '_embedding'
 69 |
 70 |
 71 | def is_stopping_point(current_epoch, epochs_to_train, stop_every=None, stop_x_times=None,
 72 |                       stop_on_last=True):
 73 |   if stop_on_last and current_epoch + 1 == epochs_to_train:
 74 |     return True
 75 |   if stop_x_times is not None:
 76 |     return current_epoch % np.ceil(epochs_to_train / float(stop_x_times)) == 0
 77 |   if stop_every is not None:
 78 |     return (current_epoch + 1) % stop_every == 0
 79 |
 80 |
 81 | def _fetch_dataset(path, take=None):
 82 |   dataset = inp.read_ds_zip(path)  # read
 83 |   take = len(dataset) if take is None else take
 84 |   dataset = dataset[:take]
 85 |   # print(dataset.dtype, dataset.shape, np.min(dataset), np.max(dataset))
 86 |   # dataset = inp.rescale_ds(dataset, 0, 1)
 87 |   if FLAGS.kill_depth:
 88 |     dataset[..., -1] = 0
 89 |   ut.print_info('DS fetch: %8d (%s)' % (len(dataset), path))
 90 |   return dataset
 91 |
 92 |
 93 | def l2(x):
 94 |   l = x.get_shape().as_list()[0]
 95 |   return tf.reshape(tf.sqrt(tf.reduce_sum(x ** 2, axis=1)), (l, 1))
 96 |
 97 |
 98 | def get_stats_template():
 99 |   return Bunch(
100 |     batch=[],
101 |     input=[],
102 |     encoding=[],
103 |     reconstruction=[],
104 |     total_loss=0.,
105 |     start=time.time())
106 |
107 |
108 | def guard_nan(x):
109 |   return x if not math.isnan(x) else -1.
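l2() above returns one Euclidean norm per row, shaped (N, 1) so the result broadcasts against (N, d) encodings. A NumPy restatement for reference (not used by the code itself):

import numpy as np

def l2_np(x):
  # per-row Euclidean norm, kept 2-D: (N, d) -> (N, 1)
  return np.sqrt((x ** 2).sum(axis=1, keepdims=True))

print(l2_np(np.array([[3., 4.], [0., 1.]])))  # [[5.], [1.]]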
110 | 111 | 112 | def _blur_expand(input): 113 | k_size = 9 114 | kernels = [2, 4, 6] 115 | channels = [input] + [nut.blur_gaussian(input, k, k_size)[0] for k in kernels] 116 | res = tf.concat(channels, axis=3) 117 | return res 118 | 119 | 120 | class Autoencoder: 121 | train_set, test_set = None, None 122 | permutation = None 123 | batch_shape = None 124 | epoch_size = None 125 | 126 | input, target = None, None # AE placeholders 127 | encode, decode = None, None # AE operations 128 | model = None # interpreted model 129 | 130 | encoding = None # AE predictive evaluation placeholder 131 | eval_decode, eval_loss = None, None # AE evaluation 132 | 133 | inputs, targets = None, None # Noise/Predictive placeholders 134 | raw_inputs, raw_targets = None, None # inputs in network-friendly representation 135 | models = None # Noise/Predictive interpreted models 136 | 137 | optimizer, _train = None, None 138 | loss_ae, loss_reco, loss_pred, loss_dn = None, None, None, None # Objectives 139 | loss_total = None 140 | losses = [] 141 | 142 | step = None # operation 143 | step_var = None # variable 144 | 145 | vis_summary, vis_placeholder = None, None 146 | image_summaries = None 147 | visualization_batch_perm = None 148 | 149 | 150 | def __init__(self, optimizer=tf.train.AdamOptimizer, need_forlders=True): 151 | self.optimizer_constructor = optimizer 152 | FLAGS.input_name = inp.get_input_name(FLAGS.input_path) 153 | if need_forlders: 154 | ut.configure_folders(FLAGS) 155 | ut.print_flags(FLAGS) 156 | 157 | # MISC 158 | 159 | 160 | def get_past_epochs(self): 161 | return int(self.step.eval() / self.epoch_size) 162 | 163 | @staticmethod 164 | def get_checkpoint_path(): 165 | # print(os.path.join(FLAGS.save_path, CHECKPOINT_NAME), len(CHECKPOINT_NAME)) 166 | return os.path.join(FLAGS.save_path, CHECKPOINT_NAME) 167 | 168 | def get_latest_checkpoint(self): 169 | return tf.train.latest_checkpoint( 170 | self.get_checkpoint_path()[:-len(EMB_SUFFIX)], 171 | latest_filename='checkpoint' 172 | ) 173 | 174 | 175 | # DATA 176 | 177 | 178 | def fetch_datasets(self): 179 | if FLAGS.max_epochs == 0: 180 | FLAGS.input_path = FLAGS.test_path 181 | self.train_set = _fetch_dataset(FLAGS.input_path) 182 | self.epoch_size = int(self.train_set.shape[0] / FLAGS.batch_size) 183 | self.batch_shape = [FLAGS.batch_size] + list(self.train_set.shape[1:]) 184 | 185 | reuse_train = FLAGS.test_path == FLAGS.input_path or FLAGS.test_path == '' 186 | self.test_set = self.train_set.copy() if reuse_train else _fetch_dataset(FLAGS.test_path) 187 | take_test = int(FLAGS.test_max) if FLAGS.test_max > 1 else int(FLAGS.test_max * len(self.test_set)) 188 | ut.print_info('take %d from test' % take_test) 189 | self.test_set = self.test_set[:take_test] 190 | 191 | def _batch_generator(self, x=None, y=None, shuffle=True, batches=None): 192 | """Returns BATCH_SIZE of couples of subsequent images""" 193 | x = x if x is not None else self._get_blurred_dataset() 194 | y = y if y is not None else x 195 | batches = batches if batches is not None else int(np.floor(len(x) / FLAGS.batch_size)) 196 | self.permutation = np.arange(len(x)) 197 | self.permutation = self.permutation if not shuffle else np.random.permutation(self.permutation) 198 | 199 | for i in range(batches): 200 | batch_indexes = self.permutation[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size] 201 | # batch = np.stack((dataset[batch_indexes], dataset[batch_indexes + 1], dataset[batch_indexes + 2]), axis=1) 202 | yield x[batch_indexes], y[batch_indexes] 203 | 204 | def 
_batch_permutation_generator(self, length, start=0, shuffle=True, batches=None): 205 | self.permutation = np.arange(length) + start 206 | self.permutation = self.permutation if not shuffle else np.random.permutation(self.permutation) 207 | for i in range(int(length/FLAGS.batch_size)): 208 | if batches is not None and i >= batches: 209 | break 210 | yield self.permutation[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size] 211 | 212 | _blurred_dataset, _last_blur = None, 0 213 | 214 | def _get_blur_sigma(self): 215 | calculated_sigma = FLAGS.blur - int(10 * self.step.eval() / FLAGS.blur_decrease) / 10.0 216 | return max(0, calculated_sigma) 217 | 218 | # @ut.timeit 219 | def _get_blurred_dataset(self): 220 | if FLAGS.blur != 0: 221 | current_sigma = self._get_blur_sigma() 222 | if current_sigma != self._last_blur: 223 | # print(self._last_blur, current_sigma) 224 | self._last_blur = current_sigma 225 | self._blurred_dataset = inp.apply_gaussian(self.train_set, sigma=current_sigma) 226 | ut.print_info('blur s:%.1f[%.1f>%.1f]' % (current_sigma, self.train_set[2, 10, 10, 0], self._blurred_dataset[2, 10, 10, 0])) 227 | return self._blurred_dataset if self._blurred_dataset is not None else self.train_set 228 | return self.train_set 229 | 230 | 231 | # TRAIN 232 | 233 | 234 | def build_ae_model(self): 235 | self.input = tf.placeholder(tf.uint8, self.batch_shape, name='input') 236 | self.target = tf.placeholder(tf.uint8, self.batch_shape, name='target') 237 | self.step = tf.Variable(0, trainable=False, name='global_step') 238 | root = self._image_to_tensor(self.input) 239 | target = self._image_to_tensor(self.target) 240 | 241 | model = interpreter.build_autoencoder(root, FLAGS.net) 242 | 243 | self.encode = model.encode 244 | 245 | self.model = model 246 | self.encoding = tf.placeholder(self.encode.dtype, self.encode.get_shape(), name='encoding') 247 | eval_decode = interpreter.build_decoder(self.encoding, model.config, reuse=True) 248 | print(target, eval_decode) 249 | self.eval_loss = interpreter.l2_loss(target, eval_decode, name='predictive_reconstruction') 250 | self.eval_decode = self._tensor_to_image(eval_decode) 251 | 252 | self.loss_ae = interpreter.l2_loss(target, model.decode, name='reconstruction') 253 | self.decode = self._tensor_to_image(model.decode) 254 | self.losses = [self.loss_ae] 255 | 256 | def build_predictive_model(self): 257 | self.build_ae_model() # builds on top of AE model. 
Due to auxilary operations init 258 | self.inputs = tf.placeholder(tf.uint8, [3] + self.batch_shape, name='inputs') 259 | self.targets = tf.placeholder(tf.uint8, [3] + self.batch_shape, name='targets') 260 | 261 | # transform inputs 262 | self.raw_inputs = [self._image_to_tensor(self.inputs[i]) for i in range(3)] 263 | self.raw_targets = [self._image_to_tensor(self.targets[i]) for i in range(3)] 264 | 265 | # build AE objective for triplet 266 | config = self.model.config 267 | models = [interpreter.build_autoencoder(x, config) for x in self.raw_inputs] 268 | reco_losses = [1./3 * interpreter.l2_loss(models[i].decode, self.raw_targets[i]) for i in range(3)] # business as usual 269 | self.models = models 270 | 271 | # build predictive objective 272 | pred_loss_2 = self._prediction_decode(models[1].encode*2 - models[0].encode, self.raw_targets[2], models[2]) 273 | pred_loss_0 = self._prediction_decode(models[1].encode*2 - models[2].encode, self.raw_targets[0], models[0]) 274 | 275 | # build regularized distance objective 276 | dist_loss1 = self._distance_loss(models[1].encode - models[0].encode) 277 | dist_loss2 = self._distance_loss(models[1].encode - models[2].encode) 278 | 279 | # Stitch it all together and train 280 | self.loss_reco = tf.add_n(reco_losses) 281 | self.loss_pred = pred_loss_0 + pred_loss_2 282 | self.loss_dist = dist_loss1 + dist_loss2 283 | self.losses = [self.loss_reco, self.loss_pred] 284 | 285 | def _distance_loss(self, distances): 286 | error = tf.nn.relu(l2(distances) - FLAGS.distance ** 2) 287 | return tf.reduce_sum(error) 288 | 289 | def _prediction_decode(self, prediction, target, model): 290 | """Predict encoding t3 by encoding (t2 and t1) and expect a good reconstruction""" 291 | predict_decode = interpreter.build_decoder(prediction, self.model.config, reuse=True, masks=model.mask_list) 292 | predict_loss = 1./2 * interpreter.l2_loss(predict_decode, target, alpha=FLAGS.alpha) 293 | self.models += [predict_decode] 294 | return predict_loss * FLAGS.gamma 295 | 296 | 297 | def build_denoising_model(self): 298 | self.build_predictive_model() # builds on top of predictive model. Reuses triplet encoding 299 | 300 | # build denoising objective 301 | models = self.models 302 | self.loss_dn = self._noisy_decode(models[1]) 303 | self.losses = [self.loss_reco, self.loss_pred, self.loss_dist, self.loss_dn] 304 | 305 | def _noisy_decode(self, model): 306 | """Distort middle encoding with [<= 1/3*dist(neigbour)] and demand good reconstruction""" 307 | # dist = l2(x1 - x2) 308 | # noise = dist * self.epsilon_sphere_noise() 309 | # tf.stop_gradient(noise) 310 | noise = tf.random_normal(self.model.encode.get_shape().as_list()) * FLAGS.epsilon 311 | noisy_encoding = noise + self.models[1].encode 312 | tf.stop_gradient(noisy_encoding) # or maybe here, who knows 313 | noisy_decode = interpreter.build_decoder(noisy_encoding, model.config, reuse=True, masks=model.mask_list) 314 | loss = interpreter.l2_loss(noisy_decode, self.raw_targets[1], alpha=FLAGS.beta) 315 | self.models += [noisy_decode] 316 | return loss 317 | 318 | def _tensor_to_image(self, net): 319 | with tf.name_scope('to_image'): 320 | if FLAGS.new_blur: 321 | net = net[..., :self.batch_shape[-1]] 322 | net = tf.nn.relu(net) 323 | net = tf.cast(net <= 1, net.dtype) * net * 255 324 | net = tf.cast(net, tf.uint8) 325 | return net 326 | 327 | def _image_to_tensor(self, image): 328 | with tf.name_scope('args_transform'): 329 | net = tf.cast(image, tf.float32) / 255. 
330 |       if FLAGS.new_blur:
331 |         net = _blur_expand(net)
332 |         FLAGS.blur = 0.
333 |       return net
334 |
335 |   def _init_optimizer(self):
336 |     self.loss_total = tf.add_n(self.losses, 'loss_total')
337 |     self.optimizer = self.optimizer_constructor(learning_rate=FLAGS.learning_rate)
338 |     self._train = self.optimizer.minimize(self.loss_total, global_step=self.step)
339 |
340 |
341 |   # MAIN
342 |
343 |
344 |   def train(self):
345 |     self.fetch_datasets()
346 |     if FLAGS.model == AUTOENCODER:
347 |       self.build_ae_model()
348 |     elif FLAGS.model == PREDICTIVE:
349 |       self.build_predictive_model()
350 |     else:
351 |       self.build_denoising_model()
352 |     self._init_optimizer()
353 |
354 |     with tf.Session() as sess:
355 |       sess.run(tf.global_variables_initializer())
356 |       self._on_training_start(sess)
357 |
358 |       try:
359 |         for current_epoch in range(FLAGS.max_epochs):
360 |           start = time.time()
361 |           full_set_blur = len(self.train_set) < 50000
362 |           ds = self._get_blurred_dataset() if full_set_blur else self.train_set
363 |           if FLAGS.model == AUTOENCODER:
364 |
365 |             # Autoencoder Training
366 |             for batch in self._batch_generator():
367 |               summs, encoding, reconstruction, loss, _, step = sess.run(
368 |                 [self.summs_train, self.encode, self.decode, self.loss_ae, self._train, self.step],
369 |                 feed_dict={self.input: batch[0], self.target: batch[1]}
370 |               )
371 |               self._on_batch_finish(summs, loss, batch, encoding, reconstruction)
372 |
373 |           else:
374 |
375 |             # Predictive and Denoising training
376 |             for batch_indexes in self._batch_permutation_generator(len(ds)-2):
377 |               batch = np.stack((ds[batch_indexes], ds[batch_indexes + 1], ds[batch_indexes + 2]))
378 |               if not full_set_blur:
379 |                 batch = np.stack((
380 |                   inp.apply_gaussian(ds[batch_indexes], sigma=self._get_blur_sigma()),
381 |                   inp.apply_gaussian(ds[batch_indexes+1], sigma=self._get_blur_sigma()),
382 |                   inp.apply_gaussian(ds[batch_indexes+2], sigma=self._get_blur_sigma())
383 |                 ))
384 |
385 |               summs, loss, _ = sess.run(
386 |                 [self.summs_train, self.loss_total, self._train],
387 |                 feed_dict={self.inputs: batch, self.targets: batch})
388 |               self._on_batch_finish(summs, loss)
389 |
390 |           self._on_epoch_finish(current_epoch, start, sess)
391 |         self._on_training_finish(sess)
392 |       except KeyboardInterrupt:
393 |         self._on_training_abort(sess)
394 |
395 |   def inference(self, max=10**6):
396 |     self.fetch_datasets()
397 |     self.build_ae_model()
398 |
399 |     with tf.Session() as sess:
400 |       sess.run(tf.global_variables_initializer())
401 |       self.saver = tf.train.Saver()
402 |       self._restore_model(sess)
403 |
404 |       encoding, decoding = None, None
405 |       for i in range(len(self.train_set)):
406 |         batch = np.expand_dims(self.train_set[i], axis=0)
407 |         enc, dec = sess.run(
408 |           [self.encode, self.decode],
409 |           feed_dict={self.input: batch}
410 |         )
411 |         encoding = enc if i == 0 else np.vstack((encoding, enc))
412 |         decoding = dec if i == 0 else np.vstack((decoding, dec))
413 |         print('\r%5d/%d' % (i, len(self.train_set)), end='')
414 |         if i >= max:
415 |           break
416 |       return encoding, decoding
417 |
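evaluate() below scores the predictive objective directly in latent space: `expected = enc[1:-1]*2 - enc[:-2]` is a linear extrapolation along the trajectory (predicting code t+1 from codes t and t-1), while `average` is the midpoint of the two preceding codes. A toy check of that arithmetic:

import numpy as np

enc = np.array([[0., 0.], [1., 2.], [2., 4.], [3., 6.]])  # a perfectly linear trajectory
expected = enc[1:-1] * 2 - enc[:-2]     # 2*e[t] - e[t-1], compared against e[t+1]
average = 0.5 * (enc[1:-1] + enc[:-2])  # midpoint of e[t-1] and e[t]
print(expected)  # [[2. 4.], [3. 6.]] == enc[2:], a linear path predicts itself exactly
print(average)   # [[0.5 1. ], [1.5 3. ]]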
424 |   # @ut.timeit
425 |   def evaluate(self, sess, take):
426 |     digest = Bunch(encoded=None, reconstructed=None, source=None,
427 |                    loss=.0, eval_loss=.0, dumb_loss=.0)
428 |     blurred = inp.apply_gaussian(self.test_set, self._get_blur_sigma())
429 |     # Encode
430 |     for i, batch in enumerate(self._batch_generator(blurred, shuffle=False)):
431 |       encoding = self.encode.eval(feed_dict={self.input: batch[0]})
432 |       digest.encoded = ut.concatenate(digest.encoded, encoding)
433 |     # Save encoding for visualization
434 |     encoded_no_nan = np.nan_to_num(digest.encoded)
435 |     self.embedding_assign.eval(feed_dict={self.embedding_test_ph: encoded_no_nan})
436 |     try:
437 |       self.embedding_saver.save(sess, self.get_checkpoint_path() + EMB_SUFFIX)
438 |     except Exception:
439 |       ut.print_info("Unexpected error: %s" % str(sys.exc_info()[0]), color=33)
440 |
441 |     # Calculate expected evaluation
442 |     expected = digest.encoded[1:-1]*2 - digest.encoded[:-2]
443 |     average = 0.5 * (digest.encoded[1:-1] + digest.encoded[:-2])
444 |     digest.size = len(expected)
445 |     # evaluation summaries
446 |     self.summary_writer.add_summary(self.eval_summs.eval(
447 |       feed_dict={self.blur_ph: self._get_blur_sigma()}),
448 |       global_step=self.get_past_epochs())
449 |     # evaluation losses
450 |     for p in self._batch_permutation_generator(digest.size, shuffle=False):
451 |       digest.loss += self.eval_loss.eval(feed_dict={self.encoding: digest.encoded[p + 2], self.target: blurred[p + 2]})
452 |       digest.eval_loss += self.eval_loss.eval(feed_dict={self.encoding: expected[p], self.target: blurred[p + 2]})
453 |       digest.dumb_loss += self.loss_ae.eval(feed_dict={self.input: blurred[p], self.target: blurred[p + 2]})
454 |
455 |     # Reconstruction visualizations
456 |     for p in self._batch_permutation_generator(digest.size, shuffle=True, batches=1):
457 |       self.visualization_batch_perm = self.visualization_batch_perm if self.visualization_batch_perm is not None else p
458 |       p = self.visualization_batch_perm
459 |       # digest.source = self.eval_decode.eval(feed_dict={self.encoding: expected[p]})[:take]  # immediately overwritten below
460 |       digest.source = blurred[(p+2)[:take]]
461 |       digest.reconstructed = self.eval_decode.eval(feed_dict={self.encoding: average[p]})[:take]
462 |       self._eval_image_summaries(blurred[p], digest.encoded[p], average[p], expected[p])
463 |
464 |     digest.dumb_loss = guard_nan(digest.dumb_loss)
465 |     digest.eval_loss = guard_nan(digest.eval_loss)
466 |     digest.loss = guard_nan(digest.loss)
467 |     return digest
468 |
469 |   def _eval_image_summaries(self, blurred_batch, actual, average, expected):
470 |     """Create Tensorboard summaries with image reconstructions"""
471 |     noisy = expected + np.random.randn(*expected.shape) * FLAGS.epsilon
472 |
473 |     summary = self.image_summaries['orig'].eval(feed_dict={self.input: blurred_batch})
474 |     self.summary_writer.add_summary(summary, global_step=self.get_past_epochs())
475 |
476 |     self._eval_image_summary('midd', average)
477 |     # self._eval_image_summary('reco', actual)
478 |     self._eval_image_summary('pred', expected)
479 |     self._eval_image_summary('nois', noisy)
480 |
481 |   def _eval_image_summary(self, name, encoding_batch):
482 |     summary = self.image_summaries[name].eval(feed_dict={self.encoding: encoding_batch})
483 |     self.summary_writer.add_summary(summary, global_step=self.get_past_epochs())
484 |
485 |   def _add_decoding_summary(self, name, var, collection='train'):
486 |     var = var[:FLAGS.visualiza_max]
487 |     var = tf.concat(tf.unstack(var), axis=0)
488 |     var = tf.expand_dims(var, dim=0)
489 |     color_s = tf.summary.image(name, var[..., :3], max_outputs=FLAGS.visualiza_max)
490 |     var = tf.expand_dims(var[..., 3],
dim=3) 495 | bw_s = tf.summary.image('depth_' + name, var, max_outputs=FLAGS.visualiza_max) 496 | return tf.summary.merge([color_s, bw_s]) 497 | 498 | 499 | # TRAINING PROGRESS EVENTS 500 | 501 | 502 | def _on_training_start(self, sess): 503 | # Writers and savers 504 | self.summary_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) 505 | self.saver = tf.train.Saver() 506 | self._build_embedding_saver(sess) 507 | self._restore_model(sess) 508 | # Loss summaries 509 | self._build_summaries() 510 | 511 | self.epoch_stats = get_stats_template() 512 | self.stats = Bunch( 513 | epoch_accuracy=[], 514 | epoch_reconstructions=[], 515 | permutation=None 516 | ) 517 | # if FLAGS.dev: 518 | # plt.ion() 519 | # plt.show() 520 | 521 | def _build_summaries(self): 522 | # losses 523 | with tf.name_scope('losses'): 524 | loss_names = ['loss_autoencoder', 'loss_predictive', 'loss_distance', 'loss_denoising'] 525 | for i, loss in enumerate(self.losses): 526 | self._add_loss_summary(loss_names[i], loss) 527 | self._add_loss_summary('loss_total', self.loss_total) 528 | self.summs_train = tf.summary.merge_all('train') 529 | # reconstructions 530 | with tf.name_scope('decodings'): 531 | self.image_summaries = { 532 | 'orig': self._add_decoding_summary('0_original_input', self.input), 533 | 'reco': self._add_decoding_summary('1_reconstruction', self.eval_decode), 534 | 'pred': self._add_decoding_summary('2_prediction', self.eval_decode), 535 | 'midd': self._add_decoding_summary('3_averaged', self.eval_decode), 536 | 'nois': self._add_decoding_summary('4_noisy', self.eval_decode) 537 | } 538 | # visualization 539 | fig = vis.get_figure() 540 | fig.canvas.draw() 541 | self.vis_placeholder = tf.placeholder(tf.uint8, ut.fig2rgb_array(fig).shape) 542 | self.vis_summary = tf.summary.image('visualization', self.vis_placeholder) 543 | # embedding 544 | dists = l2(self.embedding_test[:-1] - self.embedding_test[1:]) 545 | self.dist = dists 546 | metrics = [] 547 | 548 | metrics.append(tf.summary.histogram('point_distance', dists)) 549 | metrics.append(tf.summary.scalar('training/trajectory_length', tf.reduce_sum(dists))) 550 | self.blur_ph = tf.placeholder(dtype=tf.float32) 551 | metrics.append(tf.summary.scalar('training/blur_sigma', self.blur_ph)) 552 | 553 | pred = self.embedding_test[1:-1]*2 - self.embedding_test[0:-2] 554 | pred_error = l2(pred - self.embedding_test[2:]) 555 | 556 | mean_dist, mean_pred_error = tf.reduce_mean(dists), tf.reduce_mean(pred_error) 557 | improvement = (mean_dist-mean_pred_error)/mean_dist 558 | 559 | pairwise_improvement = tf.nn.relu(dists[1:] - pred_error) 560 | pairwise_improvement_bool = tf.cast(pairwise_improvement > 0, pairwise_improvement.dtype) 561 | self.pairwise_improvement_bool = pairwise_improvement_bool 562 | 563 | metrics.append(tf.summary.scalar('training/avg_dist', mean_dist)) 564 | metrics.append(tf.summary.scalar('training/pred_dist', mean_pred_error)) 565 | metrics.append(tf.summary.scalar('training/improvement', improvement)) 566 | metrics.append(tf.summary.scalar('training/improvement_abs', tf.nn.relu(improvement))) 567 | metrics.append(tf.summary.histogram('training/improvement_abs_hist', nut.nan_to_zero(improvement))) 568 | metrics.append(tf.summary.scalar('training/improvement_pairwise', tf.reduce_mean(pairwise_improvement_bool))) 569 | metrics.append(tf.summary.histogram('training/improvement_pairwise_hist', pairwise_improvement_bool)) 570 | self.eval_summs = tf.summary.merge(metrics) 571 | 572 | 573 | def _build_embedding_saver(self, sess): 574 | """To use 
the embedding visualizer, data has to be stored in a variable; 575 | since we would like to visualize the TEST_SET, this variable should not affect the 576 | common checkpoint of the model. 577 | Hence, we build a separate variable with a separate saver.""" 578 | embedding_shape = [int(len(self.test_set) / FLAGS.batch_size) * FLAGS.batch_size, 579 | self.encode.get_shape().as_list()[1]] 580 | tsv_path = os.path.join(FLAGS.logdir, 'metadata.tsv') 581 | 582 | self.embedding_test_ph = tf.placeholder(tf.float32, embedding_shape, name='embedding') 583 | self.embedding_test = tf.Variable(tf.random_normal(embedding_shape), name='test_embedding', trainable=False) 584 | self.embedding_assign = self.embedding_test.assign(self.embedding_test_ph) 585 | self.embedding_saver = tf.train.Saver(var_list=[self.embedding_test]) 586 | 587 | config = projector.ProjectorConfig() 588 | embedding = config.embeddings.add() 589 | embedding.tensor_name = self.embedding_test.name 590 | embedding.sprite.image_path = './sprite.png' 591 | embedding.sprite.single_image_dim.extend([80, 80]) 592 | embedding.metadata_path = './metadata.tsv' 593 | projector.visualize_embeddings(self.summary_writer, config) 594 | sess.run(tf.variables_initializer([self.embedding_test], name='init_embeddings')) 595 | 596 | # build sprite image 597 | ut.images_to_sprite(self.test_set, path=os.path.join(FLAGS.logdir, 'sprite.png')) 598 | ut.generate_tsv(len(self.test_set), tsv_path) 599 | 600 | def _add_loss_summary(self, name, var, collection='train'): 601 | if var is not None: 602 | tf.summary.scalar(name, var, [collection]) 603 | tf.summary.scalar('log_' + name, tf.log(var), [collection]) 604 | 605 | def _restore_model(self, session): 606 | latest_checkpoint = self.get_latest_checkpoint() 607 | print(latest_checkpoint) 608 | if latest_checkpoint is not None: 609 | latest_checkpoint = latest_checkpoint.replace(EMB_SUFFIX, '') 610 | ut.print_info("latest checkpoint: %s" % latest_checkpoint) 611 | if FLAGS.load_state and latest_checkpoint is not None: 612 | self.saver.restore(session, latest_checkpoint) 613 | ut.print_info('Restore requested. 
Previous epoch: %d' % self.get_past_epochs(), color=31) 614 | 615 | def _on_batch_finish(self, summs, loss, batch=None, encoding=None, reconstruction=None): 616 | self.summary_writer.add_summary(summs, global_step=self.step.eval()) 617 | self.epoch_stats.total_loss += loss 618 | 619 | if False: 620 | assert batch is not None and reconstruction is not None 621 | original = batch[0] 622 | vis.plot_reconstruction(original, reconstruction, interactive=True) 623 | 624 | # @ut.timeit 625 | def _on_epoch_finish(self, epoch, start_time, sess): 626 | elapsed = time.time() - start_time 627 | self.epoch_stats.total_loss = guard_nan(self.epoch_stats.total_loss) 628 | accuracy = np.nan_to_num(100000 * np.sqrt(self.epoch_stats.total_loss / np.prod(self.batch_shape) / self.epoch_size)) 629 | # SAVE 630 | if is_stopping_point(epoch, FLAGS.max_epochs, FLAGS.save_every): 631 | self.saver.save(sess, self.get_checkpoint_path()) 632 | # VISUALIZE 633 | if is_stopping_point(epoch, FLAGS.max_epochs, FLAGS.eval_every): 634 | evaluation = self.evaluate(sess, take=FLAGS.visualiza_max) 635 | data = { 636 | 'enc': np.asarray(evaluation.encoded), 637 | 'rec': np.asarray(evaluation.reconstructed), 638 | 'blu': np.asarray(evaluation.source) 639 | } 640 | error_info = '%d(%d.%d.%d)' % (np.nan_to_num(accuracy), 641 | np.nan_to_num(evaluation.loss)/evaluation.size, 642 | np.nan_to_num(evaluation.eval_loss)/evaluation.size, 643 | np.nan_to_num(evaluation.dumb_loss)/evaluation.size) 644 | meta = Bunch(suf='encodings', e='%06d' % int(self.get_past_epochs()), er=error_info) 645 | # print(data, meta.to_file_name(folder=FLAGS.save_path)) 646 | np.save(meta.to_file_name(folder=FLAGS.save_path), data) 647 | vis.plot_encoding_crosssection( 648 | evaluation.encoded, 649 | meta.to_file_name(FLAGS.save_path, 'jpg'), 650 | evaluation.source, 651 | evaluation.reconstructed, 652 | interactive=FLAGS.dev) 653 | self._save_visualization_to_summary() 654 | self.stats.epoch_accuracy.append(accuracy) 655 | self._print_epoch_info(accuracy, epoch, FLAGS.max_epochs, elapsed) 656 | if epoch + 1 != FLAGS.max_epochs: 657 | self.epoch_stats = get_stats_template() 658 | 659 | def _save_visualization_to_summary(self): 660 | image = ut.fig2rgb_array(plt.figure(num=0)) 661 | self.summary_writer.add_summary(self.vis_summary.eval(feed_dict={self.vis_placeholder: image})) 662 | 663 | def _print_epoch_info(self, accuracy, current_epoch, epochs, elapsed): 664 | epochs_past = self.get_past_epochs() - current_epoch 665 | accuracy_info = '' if accuracy is None else '| accuracy %d' % int(accuracy) 666 | epoch_past_info = '' if epochs_past is None else '+%d' % (epochs_past - 1) 667 | epoch_count = 'Epochs %2d/%d%s' % (current_epoch + 1, epochs, epoch_past_info) 668 | time_info = '%2dms/bt' % (elapsed / self.epoch_size * 1000) 669 | 670 | examples = int(np.floor(len(self.train_set) / FLAGS.batch_size)) 671 | loss_info = 't.loss:%d' % (self.epoch_stats.total_loss * 100 / (examples * np.prod(self.batch_shape[1:]))) 672 | 673 | info_string = ' '.join([epoch_count, accuracy_info, time_info, loss_info]) 674 | ut.print_time(info_string, same_line=True) 675 | 676 | def _on_training_finish(self, sess): 677 | if FLAGS.max_epochs == 0: 678 | self._on_epoch_finish(self.get_past_epochs(), time.time(), sess) 679 | best_acc = np.min(self.stats.epoch_accuracy) 680 | ut.print_time('Best Quality: %f for %s' % (best_acc, FLAGS.net)) 681 | self.summary_writer.close() 682 | 683 | def _on_training_abort(self, sess): 684 | print('Press ENTER to save the model') 685 | if getch.getch() == 
'\n': 686 | print('saving') 687 | self.saver.save(sess, self.get_checkpoint_path()) 688 | 689 | 690 | if __name__ == '__main__': 691 | args = dict([arg.split('=', maxsplit=1) for arg in sys.argv[1:]]) 692 | if len(args) <= 1: 693 | FLAGS.input_path = '../data/tmp/romb8.5.6.tar.gz' 694 | FLAGS.test_path = '../data/tmp/romb8.5.6.tar.gz' 695 | FLAGS.test_max = 2178 696 | FLAGS.max_epochs = 5 697 | FLAGS.eval_every = 1 698 | FLAGS.save_every = 1 699 | FLAGS.batch_size = 32 700 | FLAGS.blur = 0.0 701 | 702 | # FLAGS.model = 'noise' 703 | # FLAGS.beta = 1.0 704 | # FLAGS.epsilon = .000001 705 | 706 | model = Autoencoder() 707 | if FLAGS.model == 'ae': 708 | FLAGS.model = AUTOENCODER 709 | elif 'pred' in FLAGS.model: 710 | print('PREDICTIVE') 711 | FLAGS.model = PREDICTIVE 712 | elif 'noi' in FLAGS.model: 713 | print('DENOISING') 714 | FLAGS.model = DENOISING 715 | else: 716 | print('Do-di-li-doo doo-di-li-don: unknown model "%s"' % FLAGS.model) 717 | model.train() 718 | -------------------------------------------------------------------------------- /data_preprocessing/generate.command: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir tmp 3 | find . -type d -depth 2 -print | grep -v 'tmp' | cpio -pd ./tmp 4 | 5 | up=$"../../../tmp/" 6 | 7 | find . -type d -depth 3 -print | grep -v 'tmp' | while read d; do 8 | stringA=$up$d 9 | len=${#stringA} 10 | len=$((len-7)) 11 | echo ${stringA} 12 | stringA=${stringA::len} 13 | (cd $d && mogrify -path $stringA -resize 48x32 -crop 32x32+0+0 -quality 100 *.jpg) 14 | done 15 | -------------------------------------------------------------------------------- /data_preprocessing/movie.command: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mencoder "mf://*.jpg" -mf fps=5 -o out.avi -ovc lavc -lavcopts vcodec=msmpeg4v2:vbitrate=640 3 | -------------------------------------------------------------------------------- /data_preprocessing/resize.command: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #-type Grayscale 3 | #cd $d && mogrify -path $stringA -type Grayscale -resize 160x120 -crop 80x80+40+0 -quality 100 *.jpg 4 | rm -R tmp 5 | mkdir tmp 6 | find . -mindepth 2 -maxdepth 2 -type d -print | grep -v 'tmp' | cpio -pd ./tmp 7 | 8 | up=$"../../../tmp/" 9 | 10 | find . -mindepth 3 -maxdepth 3 -type d -print | grep -v 'tmp' | while read d; do 11 | stringA=$up$d 12 | len=${#stringA} 13 | len=$((len-7)) 14 | echo ${stringA} 15 | stringA=${stringA::len} 16 | (cd $d && mogrify -path $stringA -resize 160x120 -crop 120x120+20+0 -quality 100 *.jpg) 17 | done 18 | -------------------------------------------------------------------------------- /data_preprocessing/resize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo hello world 3 | mogrify -path ../32_32 -thumbnail 32x32 *.png 4 | mogrify -path thumbnail-directory -thumbnail 100x100 * 5 | 6 | mkdir 32_32 7 | cd 640_480 8 | mogrify -path ./../32_32 -resize 48x32 -crop 32x32+0+0 -quality 100 *.jpg -------------------------------------------------------------------------------- /data_preprocessing/resize_grey.command: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #-type Grayscale 3 | #cd $d && mogrify -path $stringA -type Grayscale -resize 160x120 -crop 80x80+40+0 -quality 100 *.jpg 4 | rm -R tmp_grey 5 | mkdir tmp_grey 6 | 7 | 8 | find . 
-mindepth 2 -maxdepth 2 -type d -print | grep -v 'tmp' | cpio -pd ./tmp_grey/ 9 | 10 | up=$"../../../tmp_grey/" 11 | 12 | find . -mindepth 3 -maxdepth 3 -type d -print | grep -v 'tmp' | while read d; do 13 | stringA=$up$d 14 | len=${#stringA} 15 | len=$((len-7)) 16 | echo ${stringA} 17 | stringA=${stringA::len} 18 | # echo ${stringA} 19 | (cd $d && mogrify -path $stringA -resize 160x120 -crop 40x40+60+40 -fx '(r+g+b)/3' -quality 100 *.jpg) 20 | done 21 | 22 | chmod 777 tmp_grey 23 | -------------------------------------------------------------------------------- /data_preprocessing/resize_no_gun.command: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #-type Grayscale 3 | #cd $d && mogrify -path $stringA -type Grayscale -resize 160x120 -crop 80x80+40+0 -quality 100 *.jpg 4 | rm -R tmp 5 | mkdir tmp 6 | find . -mindepth 2 -maxdepth 2 -type d -print | grep -v 'tmp' | cpio -pd ./tmp 7 | 8 | up=$"../../../tmp/" 9 | 10 | find . -mindepth 3 -maxdepth 3 -type d -print | grep -v 'tmp' | while read d; do 11 | stringA=$up$d 12 | len=${#stringA} 13 | len=$((len-7)) 14 | echo ${stringA} 15 | stringA=${stringA::len} 16 | (cd $d && mogrify -path $stringA -resize 160x120 -crop 160x80+0+0 -quality 100 *.jpg) 17 | done 18 | 19 | chmod 777 tmp 20 | -------------------------------------------------------------------------------- /data_preprocessing/tar_all.command: -------------------------------------------------------------------------------- 1 | ext=$".tar.gz" 2 | find ./tmp/ -type d -maxdepth 1 -mindepth 1 | while read d; do 3 | tar -zcvf $d$ext $d/ 4 | done -------------------------------------------------------------------------------- /data_preprocessing/to32_32.command: -------------------------------------------------------------------------------- 1 | mkdir 32_32 2 | cd 640_480 3 | mogrify -path ./../32_32 -resize 48x32 -crop 32x32+0+0 -quality 100 *.jpg -------------------------------------------------------------------------------- /experiments.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import utils as ut 4 | import input 5 | import DoomModel as dm 6 | import pickle 7 | from datetime import datetime as dt 8 | import sys 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | 12 | 13 | def search_learning_rate(lrs=[0.001, 0.0004, 0.0001, 0.00003,], 14 | epochs=500): 15 | FLAGS.suffix = 'grid_lr' 16 | ut.print_info('START: search_learning_rate', color=31) 17 | 18 | best_result, best_args = None, None 19 | result_summary, result_list = [], [] 20 | 21 | for lr in lrs: 22 | ut.print_info('STEP: search_learning_rate', color=31) 23 | FLAGS.learning_rate = lr 24 | model = model_class() 25 | meta, accuracy_by_epoch = model.train(epochs) 26 | result_list.append((ut.to_file_name(meta), accuracy_by_epoch)) 27 | best_accuracy = np.min(accuracy_by_epoch) 28 | result_summary.append('\n\r lr:%2.5f \tq:%.2f' % (lr, best_accuracy)) 29 | if best_result is None or best_result > best_accuracy: 30 | best_result = best_accuracy 31 | best_args = lr 32 | 33 | meta = {'suf': 'grid_lr_bs', 'e': epochs, 'lrs': lrs, 'acu': best_result, 34 | 'bs': FLAGS.batch_size, 'h': model.get_layer_info()} 35 | pickle.dump(result_list, open('search_learning_rate%d.txt' % epochs, "wb")) 36 | ut.plot_epoch_progress(meta, result_list) 37 | print(''.join(result_summary)) 38 | ut.print_info('BEST Q: %.2f IS ACHIEVED FOR LR: %f' % (best_result, best_args), 36) 39 | 40 | 
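# A hypothetical extension (sketch, not from the original experiments): the
# same grid-search pattern generalizes to any flag. `search_blur_sigma` is an
# assumed name; it reuses only what this file already provides -- FLAGS.blur
# (assumed to be read by the training pipeline) and
# model_class().train(epochs) -> (meta, accuracy_by_epoch).
def search_blur_sigma(sigmas=(0.0, 1.0, 3.0), epochs=500):
  FLAGS.suffix = 'grid_blur'
  best_result, best_args = None, None
  for sigma in sigmas:
    FLAGS.blur = sigma
    model = model_class()
    meta, accuracy_by_epoch = model.train(epochs)
    best_accuracy = np.min(accuracy_by_epoch)
    # keep the configuration with the lowest (best) accuracy score
    if best_result is None or best_result > best_accuracy:
      best_result, best_args = best_accuracy, sigma
  ut.print_info('BEST Q: %.2f IS ACHIEVED FOR blur sigma: %f' % (best_result, best_args), 36)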
41 | def search_batch_size(bss=[50], strides=[1, 2, 5, 20], epochs=500): 42 | FLAGS.suffix = 'grid_bs' 43 | ut.print_info('START: search_batch_size', color=31) 44 | best_result, best_args = None, None 45 | result_summary, result_list = [], [] 46 | 47 | print(bss) 48 | for bs in bss: 49 | for stride in strides: 50 | ut.print_info('STEP: search_batch_size %d %d' % (bs, stride), color=31) 51 | FLAGS.batch_size = bs 52 | FLAGS.stride = stride 53 | model = model_class() 54 | start = dt.now() 55 | # meta, accuracy_by_epoch = model.train(epochs * int(bs / bss[0])) 56 | meta, accuracy_by_epoch = model.train(epochs) 57 | meta['str'] = stride 58 | meta['t'] = int((dt.now() - start).seconds) 59 | result_list.append((ut.to_file_name(meta)[22:], accuracy_by_epoch)) 60 | best_accuracy = np.min(accuracy_by_epoch) 61 | result_summary.append('\n\r bs:%d \tst:%d \tq:%.2f' % (bs, stride, best_accuracy)) 62 | if best_result is None or best_result > best_accuracy: 63 | best_result = best_accuracy 64 | best_args = (bs, stride) 65 | 66 | meta = {'suf': 'grid_batch_bs', 'e': epochs, 'acu': best_result, 67 | 'h': model.get_layer_info()} 68 | pickle.dump(result_list, open('search_batch_size%d.txt' % epochs, "wb")) 69 | ut.plot_epoch_progress(meta, result_list) 70 | print(''.join(result_summary)) 71 | 72 | ut.print_info('BEST Q: %.2f IS ACHIEVED FOR bs, st: %d %d' % (best_result, best_args[0], best_args[1]), 36) 73 | 74 | 75 | def search_layer_sizes(epochs=500): 76 | FLAGS.suffix = 'grid_h' 77 | ut.print_info('START: search_layer_sizes', color=31) 78 | best_result, best_args = None, None 79 | result_summary, result_list = [], [] 80 | 81 | for _, h_encoder in enumerate([300, 700, 2500]): 82 | for _, h_decoder in enumerate([300, 700, 2500]): 83 | for _, h_narrow in enumerate([3]): 84 | model = model_class() 85 | model.layer_encoder = h_encoder 86 | model.layer_narrow = h_narrow 87 | model.layer_decoder = h_decoder 88 | layer_info = str(model.get_layer_info()) 89 | ut.print_info('STEP: search_layer_sizes: ' + str(layer_info), color=31) 90 | 91 | meta, accuracy_by_epoch = model.train(epochs) 92 | result_list.append((layer_info, accuracy_by_epoch)) 93 | best_accuracy = np.min(accuracy_by_epoch) 94 | result_summary.append('\n\r h:%s \tq:%.2f' % (layer_info, best_accuracy)) 95 | if best_result is None or best_result > best_accuracy: 96 | best_result = best_accuracy 97 | best_args = layer_info 98 | 99 | meta = {'suf': 'grid_H_bs', 'e': epochs, 'acu': best_result, 100 | 'bs': FLAGS.batch_size, 'h': model.get_layer_info()} 101 | print(''.join(result_summary)) 102 | pickle.dump(result_list, open('search_layer_sizes%d.txt' % epochs, "wb")) 103 | ut.print_info('BEST Q: %.2f IS ACHIEVED FOR H: %s' % (best_result, best_args), 36) 104 | ut.plot_epoch_progress(meta, result_list) 105 | 106 | 107 | def search_layer_sizes_follow_up(): 108 | """Train the two best models further.""" 109 | FLAGS.save_every = 200 110 | for i in range(4): 111 | model = model_class() 112 | model.layer_encoder = 500 113 | model.layer_narrow = 3 114 | model.layer_decoder = 100 115 | model.train(600) 116 | 117 | model = model_class() 118 | model.layer_encoder = 500 119 | model.layer_narrow = 12 120 | model.layer_decoder = 500 121 | model.train(600) 122 | 123 | 124 | def print_reconstructions_along_with_originals(): 125 | FLAGS.load_from_checkpoint = './tmp/doom_bs__act|sigmoid__bs|20__h|500|5|500__init|na__inp|cbd4__lr|0.0004__opt|AO' 126 | model = model_class() 127 | files = ut.list_encodings(FLAGS.save_path) 128 | last_encoding = files[-1] 129 | print(last_encoding) 130 | take_only = 20 131 | data = 
np.loadtxt(last_encoding)[0:take_only] 132 | reconstructions = model.decode(data) 133 | original, _ = input.get_images(FLAGS.input_path, at_most=take_only) 134 | ut.print_side_by_side(original, reconstructions) 135 | 136 | 137 | def train_couple_8_models(): 138 | FLAGS.input_path = '../data/tmp/8_pos_delay_3/img/' 139 | 140 | model = model_class() 141 | model.set_layer_sizes([500, 5, 500]) 142 | for i in range(10): 143 | model.train(1000) 144 | 145 | model = model_class() 146 | model.set_layer_sizes([1000, 10, 1000]) 147 | for i in range(20): 148 | model.train(1000) 149 | 150 | 151 | if __name__ == "__main__": 152 | # run function if provided as console params 153 | epochs = 100 154 | model_class = dm.DoomModel 155 | experiment = search_learning_rate 156 | 157 | if len(sys.argv) > 1: 158 | print(sys.argv) 159 | experiment = sys.argv[1] 160 | if experiment not in locals(): 161 | ut.print_info('Function "%s" not found. List of available functions:' % experiment) 162 | ut.print_info('\n'.join([x for x in locals() if 'search' in x])) 163 | exit(0) 164 | experiment = locals()[experiment] 165 | if len(sys.argv) > 2: 166 | epochs = int(sys.argv[2]) 167 | if len(sys.argv) > 3: 168 | m = __import__(sys.argv[3]) 169 | model_class = getattr(m, sys.argv[3]) 170 | 171 | FLAGS.suffix = 'grid' 172 | # FLAGS.input_path = '../data/tmp/8_pos_delay/img/' 173 | 174 | experiment(epochs=epochs) 175 | 176 | # search_layer_sizes(epochs=epochs) 177 | # search_batch_size(epochs=epochs) 178 | # FLAGS.batch_size = 40 179 | -------------------------------------------------------------------------------- /input.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import utils as ut 4 | import json 5 | import scipy.ndimage.filters as filters 6 | import time 7 | from PIL import Image 8 | import tarfile 9 | import io 10 | 11 | 12 | INPUT_FOLDER = '../data/circle_basic_1/img/32_32' 13 | 14 | 15 | def _is_combination_of_image_depth(folder): 16 | return '/dep' not in folder and '/img' not in folder 17 | 18 | 19 | def get_action_data(folder): 20 | folder = folder.replace('/tmp_grey', '') 21 | folder = folder.replace('/tmp', '') 22 | folder = folder.replace('/img', '') 23 | folder = folder.replace('/dep', '') 24 | file = os.path.join(folder, 'action.txt') 25 | if not os.path.exists(file): 26 | return np.asarray([]) 27 | action_data = json.load(open(file, 'r'))[:] 28 | # print(action_data) 29 | res = [] 30 | # for i, action in enumerate(action_data): 31 | # print(action) 32 | # res.append( 33 | # ( 34 | # action[0], 35 | # action[1], 36 | # action[2][3] or action[2][4] or action[2][5] or action[2][6], 37 | # action[2][18] != 0 38 | # ) 39 | # ) 40 | # print([print(x[0]) for x in action_data[0:10]]) 41 | res = [x[3][:2] for x in action_data] 42 | return np.abs(np.asarray(res)) 43 | 44 | 45 | def read_ds_zip(path): 46 | dep, img = {}, {} 47 | tar = tarfile.open(path, "r:gz") 48 | 49 | for member in tar.getmembers(): 50 | if '.jpg' not in member.name or not ('/dep/' in member.name or '/img/' in member.name): 51 | # print('skipped', member) 52 | continue 53 | collection = dep if '/dep/' in member.name else img 54 | index = int(member.name.split('/')[-1][1:-4]) 55 | f = tar.extractfile(member) 56 | if f is not None: 57 | content = f.read() 58 | image = Image.open(io.BytesIO(content)) 59 | collection[index] = np.array(image) 60 | assert len(img) == len(dep) 61 | 62 | shape = [len(img)] + list(img[index].shape) 63 | shape[-1] += 1 64 | 65 | dataset = 
np.zeros(shape, np.uint8) 66 | for i, k in enumerate(sorted(img)): 67 | dataset[i, ..., :-1] = img[k] 68 | dataset[i, ..., -1] = dep[k] 69 | return dataset#, shape[1:] 70 | 71 | 72 | def get_shape_zip(path): 73 | tar = tarfile.open(path, "r:gz") 74 | for member in tar.getmembers(): 75 | if '.jpg' not in member.name or '/img/' not in member.name: 76 | continue 77 | f = tar.extractfile(member) 78 | content = f.read() 79 | image = Image.open(io.BytesIO(content)) 80 | shape = list(np.array(image).shape) 81 | shape[-1] += 1 82 | return shape 83 | 84 | 85 | def rescale_ds(ds, min, max): 86 | ut.print_info('rescale call: (min: %s, max: %s) %d' % (str(min), str(max), len(ds))) 87 | if max is None: 88 | return np.asarray(ds) - np.min(ds) 89 | ds_min, ds_max = np.min(ds), np.max(ds) 90 | ds_gap = ds_max - ds_min 91 | scale_factor = (max - min) / ds_gap 92 | ds = np.asarray(ds) * scale_factor 93 | shift_factor = min - np.min(ds) 94 | ds += shift_factor 95 | return ds 96 | 97 | 98 | def get_input_name(input_folder): 99 | splitter = '/img/' if '/img/' in input_folder else '/dep/' 100 | main_part = input_folder.split(splitter)[0] 101 | name = main_part.split('/')[-1] 102 | name = name.replace('.tar.gz', '') 103 | ut.print_info('input folder: %s -> %s' % (input_folder.split('/'), name)) 104 | return name 105 | 106 | 107 | def permute_array(array, random_state=None): 108 | return permute_data((array,), random_state=random_state)[0] 109 | 110 | 111 | def permute_data(arrays, random_state=None): 112 | """Permute multiple numpy arrays with the same order.""" 113 | if any(len(a) != len(arrays[0]) for a in arrays): 114 | raise ValueError('All arrays must be the same length.') 115 | if not random_state: 116 | random_state = np.random 117 | order = random_state.permutation(len(arrays[0])) 118 | return [a[order] for a in arrays] 119 | 120 | 121 | def apply_gaussian(images, sigma): 122 | if sigma == 0: 123 | return images 124 | 125 | res = images.copy() 126 | for i, image in enumerate(res): 127 | for channel in range(image.shape[-1]): 128 | image[:, :, channel] = filters.gaussian_filter(image[:, :, channel], sigma) 129 | return res 130 | 131 | 132 | def permute_array_in_series(array, series_length, allow_shift=True): 133 | res, permutation = permute_data_in_series((array,), series_length, allow_shift=allow_shift) 134 | return res[0], permutation 135 | 136 | 137 | def permute_data_in_series(arrays, series_length, allow_shift=True): 138 | shift_possibilities = len(arrays[0]) % series_length 139 | series_count = int(len(arrays[0]) / series_length) 140 | 141 | shift = 0 142 | if allow_shift: 143 | if shift_possibilities == 0: 144 | shift_possibilities += series_length 145 | series_count -= 1 146 | shift = np.random.randint(0, shift_possibilities+1, 1, dtype=np.int32)[0] 147 | 148 | series = np.arange(0, series_count * series_length)\ 149 | .astype(np.int32)\ 150 | .reshape((series_count, series_length)) 151 | 152 | series = np.random.permutation(series) 153 | data_permutation = series.reshape((series_count * series_length)) 154 | data_permutation += shift 155 | 156 | remaining_elements = np.arange(0, len(arrays[0])).astype(np.int32) 157 | remaining_elements = np.delete(remaining_elements, data_permutation) 158 | data_permutation = np.concatenate((data_permutation, remaining_elements)) 159 | 160 | # print('assert', len(arrays[0]), len(data_permutation)) 161 | assert len(data_permutation) == len(arrays[0]) 162 | return [a[data_permutation] for a in arrays], data_permutation 163 | 164 | 
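# Worked example for permute_data_in_series above (assumed call with
# allow_shift=False): for arrays of length 9 and series_length=3 the index
# blocks [0,1,2], [3,4,5], [6,7,8] keep their internal order but are shuffled
# as units, e.g. the permutation may come out as [3,4,5, 0,1,2, 6,7,8].
# With allow_shift=True the blocks are additionally moved by a random shift
# and any leftover indices are appended at the end, so every series of
# `series_length` consecutive frames stays intact.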
165 | def pad_set(set, batch_size): 166 | length = len(set) 167 | if length % batch_size == 0: 168 | return set 169 | padding_len = batch_size - length % batch_size 170 | if padding_len != 0: 171 | pass 172 | # ut.print_info('Non-zero padding: %d' % padding_len, color=31) 173 | # print('pad set', set.shape, select_random(padding_len, set=set).shape) 174 | return np.concatenate((set, select_random(padding_len, set=set))) 175 | 176 | 177 | def select_random(n, length=None, set=None): 178 | assert length is None or set is None 179 | length = length if set is None else len(set) 180 | select = np.random.permutation(np.arange(length, dtype=np.int))[:n] 181 | if set is None: 182 | return select 183 | else: 184 | return set[select] 185 | 186 | if __name__ == '__main__': 187 | print(read_ds_zip('/home/eugene/repo/data/tmp/romb8.5.6.tar.gz').shape) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import utils as ut 4 | import sys 5 | import os 6 | import sklearn.metrics as m 7 | import visualization as vis 8 | import matplotlib.pyplot as plt 9 | import matplotlib.ticker as ticker 10 | 11 | 12 | 13 | def get_evaluation(path): 14 | f = ut.get_latest_file(path, filter=r'.*.npy$') 15 | print(f) 16 | dict = np.load(f).item() 17 | return dict 18 | 19 | 20 | def distance(ref, pred): 21 | error = ref - pred 22 | return l2(error) 23 | 24 | 25 | def l2(error): 26 | error = error ** 2 27 | error = np.sum(error, axis=1) 28 | return np.sqrt(error) 29 | 30 | 31 | def distance_improvement(prediction_dist, naive_dist): 32 | pred_mean, naive_mean = np.mean(prediction_dist), np.mean(naive_dist) 33 | improvement = naive_mean-pred_mean if naive_mean-pred_mean > 0 else 0 34 | improvement = improvement / naive_mean 35 | # print('predictive error(naive): %.9f (%.9f) -> %.2f%%' 36 | # % (pred_mean, naive_mean, improvement*100)) 37 | return improvement, pred_mean, naive_mean 38 | 39 | 40 | def distance_binary_improvement(prediction_dist, naive_dist): 41 | pairwise = naive_dist - prediction_dist 42 | pairwise[pairwise > 0] = 1 43 | pairwise[pairwise < 0] = 0 44 | fraction = np.mean(pairwise) 45 | # print('Pairwise error improved for : %f%%' % (fraction*100), pairwise) 46 | return fraction 47 | 48 | 49 | def nn_metric(point_array): 50 | """ 51 | For every frame, how often are the previous frame t-1 and the next frame t+1 among its top-2 nearest neighbours? 52 | :param point_array: array of embedding points, one point per frame 53 | :return: fraction in [0, 1] 54 | """ 55 | total = 0 56 | for i in range(1, len(point_array)-1): 57 | x = point_array - point_array[i] 58 | d = l2(x) 59 | indexes = np.argsort(d)[:3] 60 | # assert i in indexes or d[indexes[0]][0] == 0 61 | if i-1 in indexes: 62 | total += 1 63 | if i+1 in indexes: 64 | total += 1 65 | # print(i, i-1 in indexes, i+1 in indexes, indexes) 66 | metric = total/(len(point_array) - 2) / 2 67 | # print('NN metric: %.7f%%' % (metric * 100)) 68 | return metric 69 | 70 | 71 | def nn_metric_pred(prediction, target): 72 | """ 73 | For every predicted point, how often is the target it was predicted for its nearest neighbour among all targets? 74 | :param target: array of target points, aligned with prediction 75 | :return: fraction in [0, 1] 76 | """ 77 | total = 0 78 | for i in range(len(target)): 79 | x = target - prediction[i] 80 | d = l2(x) 81 | index = np.argmin(d) 82 | # print(index, i, d) 83 | if i == index: 84 | total += 1 85 | metric = total/len(target) 86 | # print('NN metric for prediction: %.7f%%' % (metric * 100)) 87 | return metric 88 | 89 | 90 | def test_nn(): 91 | enc = np.arange(0, 100).reshape((100, 1)) 92 | 
# enc.transpose() 93 | # print(enc.shape) 94 | # print(nn_metric(enc)) 95 | assert nn_metric(enc) == 1. 96 | 97 | 98 | def test_nn_pred(): 99 | enc = np.arange(0, 100).reshape((100, 1)) 100 | # enc.transpose() 101 | # print(enc.shape) 102 | # print(nn_metric_pred(enc, enc+0.2)) 103 | assert nn_metric_pred(enc, enc) == 1. 104 | 105 | 106 | def reco_error(x, y): 107 | delta = x-y 108 | error = m.mean_squared_error(x.flatten(), y.flatten()) 109 | return error 110 | 111 | 112 | def print_folder_metrics(path): 113 | eval = get_evaluation(path) 114 | enc = eval['enc'] 115 | pred = enc[1:-1]*2 - enc[0:-2] 116 | ref = enc[2:] 117 | 118 | # print(enc[0], enc[1], pred[0], enc[2]) 119 | 120 | pred_to_target_dist, next_dist = distance(ref, pred), distance(enc[1:-1], enc[2:]) 121 | pl2, pred_d, naiv_d = distance_improvement(pred_to_target_dist, next_dist) 122 | pb = distance_binary_improvement(pred_to_target_dist, next_dist) 123 | pnn = nn_metric(enc) 124 | pnnp = (nn_metric_pred(pred, ref)*100) 125 | lreco = reco_error(eval['rec'], eval['blu']) 126 | info = '%.3f & %.3f & %.3f & %.2f' % (pl2, pb, pnn, lreco) 127 | print(info) 128 | print('pimp:%f(%f/%f) & pb:%f & pnn:%f' % (pl2, pred_d, naiv_d, pb, pnn), 'pnnp: %f' % pnnp) 129 | 130 | return info 131 | 132 | 133 | def plot_single_cross_section_3d(data, select, subplot): 134 | data = data[:, select] 135 | # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0, 136 | # subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4) 137 | 138 | d = data 139 | # subplot.plot(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=1, alpha=0.8, color='red') 140 | # subplot.scatter(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=10, alpha=0.3, marker=".", color='b') 141 | d = data 142 | subplot.scatter(d[:, 0], d[:, 1], d[:, 2], s=4, alpha=1.0, lw=0.5, 143 | c=vis._build_radial_colors(len(d)), 144 | marker=".", 145 | cmap=plt.cm.hsv) 146 | subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=0.2, alpha=0.9) 147 | 148 | subplot.set_xlim([-0.01, 1.01]) 149 | subplot.set_ylim([-0.01, 1.01]) 150 | subplot.set_zlim([-0.01, 1.01]) 151 | ticks = [] 152 | subplot.xaxis.set_ticks(ticks) 153 | subplot.yaxis.set_ticks(ticks) 154 | subplot.zaxis.set_ticks(ticks) 155 | subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f')) 156 | subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f')) 157 | 158 | 159 | def plot_single_cross_section_line(data, select, subplot): 160 | data = data[:, select] 161 | # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0, 162 | # subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4) 163 | 164 | d = data 165 | # subplot.plot(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=1, alpha=0.8, color='red') 166 | # subplot.scatter(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=10, alpha=0.3, marker=".", color='b') 167 | d = data 168 | subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4) 169 | 170 | subplot.set_xlim([-0.01, 1.01]) 171 | subplot.set_ylim([-0.01, 1.01]) 172 | subplot.set_zlim([-0.01, 1.01]) 173 | ticks = [] 174 | subplot.xaxis.set_ticks(ticks) 175 | subplot.yaxis.set_ticks(ticks) 176 | subplot.zaxis.set_ticks(ticks) 177 | subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f')) 178 | subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f')) 179 | 180 | 181 | if __name__ == '__main__': 182 | path = os.getcwd() 183 | 184 | for _, paths, _ in os.walk(path): 
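# Driver note: the walk below only collects the immediate subdirectories of
# the working directory (the loop breaks after the first iteration). With no
# subdirectories, metrics and a 3D cross-section of the encodings are produced
# for the current folder; otherwise metrics are printed for each subdirectory
# and for the folder itself, sorted by name.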
185 | print('dirs', paths) 186 | break 187 | 188 | if len(paths) == 0: 189 | print_folder_metrics(path) 190 | 191 | eval = get_evaluation(path) 192 | enc = eval['enc'] 193 | 194 | fig = vis.get_figure(shape=[1800, 900, 3]) 195 | # ax = 196 | plot_single_cross_section_3d(enc, [0, 1, 2], plt.subplot(121, projection='3d')) 197 | plot_single_cross_section_line(enc, [0, 1, 2], plt.subplot(122, projection='3d')) 198 | plt.tight_layout() 199 | plt.show() 200 | else: 201 | res = [] 202 | for d in paths + [path]: 203 | print(d) 204 | c_path = os.path.join(path, d) 205 | info = print_folder_metrics(c_path) 206 | res.append('\n%30s:\n%s' % (d, info)) 207 | res = sorted(res) 208 | print('\n'.join(res)) 209 | print(len(paths), len(res)) 210 | exit(0) 211 | 212 | 213 | if 'TensorFlow_DCIGN' in os.getcwd().split('/')[-1]: 214 | # path = '/home/eugene/repo/TensorFlow_DCIGN/tmp/noise.f20_f4__i_grid03.14.c' 215 | # path = '/mnt/code/vd/TensorFlow_DCIGN/tmp/pred.f101_f3__i_romb8.5.6' 216 | path = '/media/eugene/back up/VD_backup/tmp_epoch20_final/pred.16c3s2_32c3s2_32c3s2_23c3_f3__i_romb8.5.6_' 217 | # path = '/media/eugene/back up/VD_backup/tmp_epoch19_inputs/pred.16c3s2_32c3s2_32c3s2_16c3_f100_f3__i_grid.28.gh.360' 218 | 219 | 220 | -------------------------------------------------------------------------------- /model_interpreter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | import network_utils as nut 5 | import utils as ut 6 | import re 7 | import os 8 | from Bunch import Bunch 9 | 10 | INPUT = 'input' 11 | FC = 'fully_connected' 12 | CONV = 'convolutional' 13 | POOL = 'max_pooling' 14 | POOL_ARG = 'maxpool_with_args' 15 | DO = 'dropout' 16 | LOSS = 'loss' 17 | 18 | activation_voc = { 19 | 's': tf.nn.sigmoid, 20 | 'r': tf.nn.relu, 21 | 't': tf.nn.tanh, 22 | 'i': None 23 | } 24 | 25 | CONFIG_COLOR = 30 26 | PADDING = 'SAME' 27 | 28 | 29 | def clean_unpooling_masks(layer_config): 30 | """ 31 | Clears positional (argmax) information from POOL_ARG configs 32 | 33 | :param layer_config: list of layer descriptors 34 | :return: list of argmax mask tensors, one per ARGMAX_POOL layer, in encoder order 35 | """ 36 | mask_list = [cfg.argmax for cfg in layer_config if cfg.type == POOL_ARG] 37 | for cfg in layer_config: 38 | if cfg.type == POOL_ARG: 39 | cfg.argmax = None 40 | return mask_list 41 | 42 | 43 | def build_autoencoder(input, layer_config): 44 | reuse_model = isinstance(layer_config, list) 45 | if not reuse_model: 46 | # parse the descriptor string, e.g. '16c3s2-f4', into a list of layer configs 47 | layer_config = layer_config.replace('_', '-').split('-') 48 | layer_config = [parse_input(input)] + [parse(x) for x in layer_config] 49 | if not reuse_model: 50 | ut.print_info('Model config:', color=CONFIG_COLOR) 51 | enc = build_encoder(input, layer_config, reuse=reuse_model) 52 | dec = build_decoder(enc, layer_config, reuse=reuse_model) 53 | mask_list = clean_unpooling_masks(layer_config) 54 | losses = build_losses(layer_config) 55 | return Bunch( 56 | encode=enc, 57 | decode=dec, 58 | losses=losses, 59 | config=layer_config, 60 | mask_list=mask_list) 61 | 62 | 63 | def build_encoder(net, layer_config, i=1, reuse=False): 64 | if i == len(layer_config): 65 | return net 66 | 67 | cfg = layer_config[i] 68 | cfg.shape = net.get_shape().as_list() 69 | name = cfg.enc_op_name if reuse else None 70 | cfg.ein = net 71 | if cfg.type == FC: 72 | if len(cfg.shape) > 2: 73 | net = slim.flatten(net) 74 | net = slim.fully_connected(net, cfg.size, 
activation_fn=cfg.activation, 75 | scope=name, reuse=reuse) 76 | elif cfg.type == CONV: 77 | net = slim.conv2d(net, cfg.size, [cfg.kernel, cfg.kernel], stride=cfg.stride, 78 | activation_fn=cfg.activation, padding=PADDING, 79 | scope=name, reuse=reuse) 80 | elif cfg.type == POOL_ARG: 81 | net, cfg.argmax = nut.max_pool_with_argmax(net, cfg.kernel) 82 | # if not reuse: 83 | # mask = nut.fake_arg_max_of_max_pool(cfg.shape, cfg.kernel) 84 | # cfg.argmax_dummy = tf.constant(mask.flatten(), shape=mask.shape) 85 | elif cfg.type == POOL: 86 | net = slim.max_pool2d(net, kernel_size=[cfg.kernel, cfg.kernel], stride=cfg.kernel) 87 | elif cfg.type == DO: 88 | net = tf.nn.dropout(net, keep_prob=cfg.keep_prob) 89 | elif cfg.type == LOSS: 90 | cfg.arg1 = net 91 | elif cfg.type == INPUT: 92 | assert False 93 | 94 | if not reuse: 95 | cfg.enc_op_name = net.name.split('/')[0] 96 | if not reuse: 97 | ut.print_info('\rencoder_%d\t%s\t%s' % (i, str(net), cfg.enc_op_name), color=CONFIG_COLOR) 98 | return build_encoder(net, layer_config, i + 1, reuse=reuse) 99 | 100 | 101 | def build_decoder(net, layer_config, i=None, reuse=False, masks=None): 102 | i = i if i is not None else len(layer_config) - 1 103 | 104 | cfg = layer_config[i] 105 | name = cfg.dec_op_name if reuse else None 106 | if len(layer_config) > i + 1: 107 | if len(layer_config[i + 1].shape) != len(net.get_shape().as_list()): 108 | net = tf.reshape(net, layer_config[i + 1].shape) 109 | 110 | if i < 0 or layer_config[i].type == INPUT: 111 | return net 112 | 113 | if cfg.type == FC: 114 | net = slim.fully_connected(net, int(np.prod(cfg.shape[1:])), scope=name, 115 | activation_fn=cfg.activation, reuse=reuse) 116 | elif cfg.type == CONV: 117 | net = slim.conv2d_transpose(net, cfg.shape[-1], [cfg.kernel, cfg.kernel], stride=cfg.stride, 118 | activation_fn=cfg.activation, padding=PADDING, 119 | scope=name, reuse=reuse) 120 | elif cfg.type == POOL_ARG: 121 | if cfg.argmax is not None or masks is not None: 122 | mask = cfg.argmax if cfg.argmax is not None else masks.pop() 123 | net = nut.unpool(net, mask=mask, stride=cfg.kernel) 124 | else: 125 | net = nut.upsample(net, stride=cfg.kernel, mode='COPY') 126 | elif cfg.type == POOL: 127 | net = nut.upsample(net, cfg.kernel) 128 | elif cfg.type == DO: 129 | pass 130 | elif cfg.type == LOSS: 131 | cfg.arg2 = net 132 | elif cfg.type == INPUT: 133 | assert False 134 | if not reuse: 135 | cfg.dec_op_name = net.name.split('/')[0] 136 | if not reuse: 137 | ut.print_info('\rdecoder_%d \t%s' % (i, str(net)), color=CONFIG_COLOR) 138 | cfg.dout = net 139 | return build_decoder(net, layer_config, i - 1, reuse=reuse, masks=masks) 140 | 141 | 142 | def build_stacked_losses(model): 143 | losses = [] 144 | for i, cfg in enumerate(model.config): 145 | if cfg.type in [FC, CONV]: 146 | input = tf.stop_gradient(cfg.ein, name='stacked_breakpoint_%d' % i) 147 | net = build_encoder(input, [None, model.config[i]], reuse=True) 148 | net = build_decoder(net, [model.config[i]], reuse=True) 149 | losses.append(l2_loss(input, net, name='stacked_loss_%d' % i)) 150 | model.stacked_losses = losses 151 | 152 | 153 | def build_losses(layer_config): 154 | return [] 155 | 156 | 157 | def l2_loss(arg1, arg2, alpha=1.0, name='reco_loss'): 158 | with tf.name_scope(name): 159 | loss = tf.nn.l2_loss(arg1 - arg2) 160 | return alpha * loss 161 | 162 | 163 | def get_activation(descriptor): 164 | if 'c' not in descriptor and 'f' not in descriptor: 165 | return None 166 | activation = tf.nn.relu if 'c' in descriptor else tf.nn.sigmoid 167 | 
act_descriptor = re.search('[rsit]$', descriptor) 168 | if act_descriptor is None: 169 | return activation 170 | act_descriptor = act_descriptor.group(0) 171 | return activation_voc[act_descriptor] 172 | 173 | 174 | def _get_cfg_dummy(): 175 | return Bunch(enc_op_name=None, dec_op_name=None) 176 | 177 | 178 | def parse(descriptor): 179 | item = _get_cfg_dummy() 180 | 181 | match = re.match(r'^((\d+c\d+(s\d+)?[rsit]?)' 182 | r'|(f\d+[rsit]?)' 183 | r'|(d0?\.?\d*)' 184 | r'|(\d*\.?\d*l)' 185 | r'|(p\d+)' 186 | r'|(ap\d+))$', descriptor) 187 | assert match is not None, 'Check your writing: %s (f10i-3c64r-d0.1-p2-ap2)' % descriptor 188 | 189 | 190 | if 'f' in descriptor: 191 | item.type = FC 192 | item.activation = get_activation(descriptor) 193 | item.size = int(re.search('f\d+', descriptor).group(0)[1:]) 194 | elif 'c' in descriptor: 195 | item.type = CONV 196 | item.activation = get_activation(descriptor) 197 | item.kernel = int(re.search('c\d+', descriptor).group(0)[1:]) 198 | stride = re.search('s\d+', descriptor) 199 | item.stride = int(stride.group(0)[1:]) if stride is not None else 1 200 | item.size = int(re.search('\d+c', descriptor).group(0)[:-1]) 201 | elif 'd' in descriptor: 202 | item.type = DO 203 | item.keep_prob = float(descriptor[1:]) 204 | elif 'ap' in descriptor: 205 | item.type = POOL_ARG 206 | item.kernel = int(descriptor[2:]) 207 | elif 'p' in descriptor: 208 | item.type = POOL 209 | item.kernel = int(descriptor[1:]) 210 | elif 'l' in descriptor: 211 | item.type = LOSS 212 | item.loss_type = 'l2' 213 | item.alpha = float(descriptor.split('l')[0]) 214 | else: 215 | print('What is "%s"? Check your writing 16c2i-7c3r-p3-0.01l-f10t-d0.3' % descriptor) 216 | assert False 217 | return item 218 | 219 | 220 | def parse_input(input): 221 | item = _get_cfg_dummy() 222 | item.type = INPUT 223 | item.shape = input.get_shape().as_list() 224 | item.dout = input 225 | return item 226 | 227 | 228 | def _log_graph(): 229 | path = '/tmp/interpreter' 230 | with tf.Session() as sess: 231 | tf.global_variables_initializer() 232 | tf.summary.FileWriter(path, sess.graph) 233 | ut.print_color(os.path.abspath(path), color=33)
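# Usage sketch for the descriptor grammar above (shapes and the descriptor are
# illustrative assumptions, not taken from the original code):
#   '8c3s2r' -> convolution, 8 filters, 3x3 kernel, stride 2, relu
#   'ap2'    -> max-pooling with argmax, stride 2; the mask is kept for unpooling
#   'f4t'    -> fully connected layer of size 4, tanh
#
#   input = tf.placeholder(tf.float32, (16, 32, 32, 3), name='input')
#   model = build_autoencoder(input, '8c3s2r-ap2-f4t')
#   loss = l2_loss(input, model.decode)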
-------------------------------------------------------------------------------- /model_interpreter_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | import network_utils as nut 5 | import utils as ut 6 | import re 7 | import os 8 | from Bunch import Bunch 9 | from model_interpreter import * 10 | 11 | 12 | def _log_graph(): 13 | path = '/tmp/interpreter' 14 | with tf.Session() as sess: 15 | tf.global_variables_initializer() 16 | tf.summary.FileWriter(path, sess.graph) 17 | ut.print_color(os.path.abspath(path), color=33) 18 | 19 | 20 | def _test_parameter_reuse_conv(): 21 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 22 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-1c3-f4') 23 | 24 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 25 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-1c3-f4') 26 | l2_loss(input, model.decode) 27 | 28 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 29 | model = build_autoencoder(input, model.config) 30 | l2_loss(input, model.decode) 31 | 32 | 33 | def _test_parameter_reuse_decoder(): 34 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 35 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-16c3-f4') 36 | 37 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 38 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-16c3-f4') 39 | 40 | input_enc = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input_encoder') 41 | encoder = build_encoder(input_enc, model.config, reuse=True) 42 | 43 | input_dec = tf.placeholder(tf.float32, (2, 4), name='input_decoder') 44 | decoder = build_decoder(input_dec, model.config, reuse=True) 45 | 46 | l2_loss(input, model.decode) 47 | l2_loss(encoder, model.encode) 48 | l2_loss(decoder, model.decode) 49 | 50 | 51 | def _test_argmax_ae(): 52 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 53 | model = build_autoencoder(input, '8c3-ap2-16c3-ap2-32c3-ap2-16c3-f4') 54 | 55 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input') 56 | model = build_autoencoder(input, '8c3-ap2-16c3-ap2-32c3-ap2-16c3-f4') 57 | 58 | input_enc = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input_encoder') 59 | encoder = build_encoder(input_enc, model.config, reuse=True) 60 | 61 | input_dec = tf.placeholder(tf.float32, (2, 4), name='input_decoder') 62 | decoder = build_decoder(input_dec, model.config, reuse=True) 63 | 64 | l2_loss(input, model.decode) 65 | l2_loss(encoder, model.encode) 66 | l2_loss(decoder, model.decode) 67 | 68 | 69 | def _test_multiple_decoders_unpool_wiring(): 70 | input = tf.placeholder(tf.float32, (2, 16, 16, 3), name='input') 71 | model1 = build_autoencoder(input, '8c3-ap2-f4') 72 | model2 = build_autoencoder(input, model1.config) 73 | enc_1 = tf.placeholder(tf.float32, (2, 4), name='enc_1') 74 | enc_2 = tf.placeholder(tf.float32, (2, 4), name='enc_2') 75 | decoder1 = build_decoder(enc_1, model1.config, reuse=True, masks=model1.mask_list) 76 | decoder2 = build_decoder(enc_2, model2.config, reuse=True, masks=model2.mask_list) 77 | 78 | 79 | def _visualize_models(): 80 | input = tf.placeholder(tf.float32, (128, 160, 120, 4), name='input_fc') 81 | model = build_autoencoder(input, 'f100-f3') 82 | loss = l2_loss(input, model.decode, name='Loss_reconstruction_FC') 83 | 84 | input = tf.placeholder(tf.float32, (128, 160, 120, 4), name='input_conv') 85 | model = build_autoencoder(input, '16c3s2-32c3s2-32c3s2-16c3-f3') 86 | loss = l2_loss(input, model.decode, name='Loss_reconstruction_conv') 87 | 88 | input = tf.placeholder(tf.float32, (128, 160, 120, 4), name='input_wwae') 89 | model = build_autoencoder(input, '16c3-ap2-32c3-ap2-16c3-f3') 90 | loss = l2_loss(input, model.decode, name='Loss_reconstruction_WWAE') 91 | 92 | 93 | def _test_stacked_ae(): 94 | input = tf.placeholder(tf.float32, (2, 16, 16, 3), name='input') 95 | model1 = build_autoencoder(input, 'f500-f100-f5') 96 | losses = build_stacked_losses(model1) 97 | 98 | 99 | if __name__ == '__main__': 100 | # print(re.match('\d+c\d+(s\d+)?[r|s|i|t]?', '8c3s2')) 101 | # model = build_autoencoder(tf.placeholder(tf.float32, (2, 16, 16, 3), name='input'), '8c3s2-16c3s2-30c3s2-16c3-f4') 102 | # _test_multiple_decoders_unpool_wiring() 103 | # _visualize_models() 104 | x = _test_stacked_ae() 105 | 106 | # build_autoencoder(tf.placeholder(tf.float32, (2, 16, 16, 3), name='input'), '10c3-f100-f10') 107 | # _test_parameter_reuse_conv() 108 | # _test_parameter_reuse_decoder() 109 | # _test_argmax_ae() 110 | _log_graph() 111 | -------------------------------------------------------------------------------- /network_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | import tools.checkpoint_utils as ch_utils 5 | import scipy.stats as st 6 | import inspect 7 | 8 | # POOLING 9 | 10 | 11 | def max_pool_with_argmax(net, stride): 12 | """ 13 | TensorFlow's default implementation does not provide a gradient operation for max_pool_with_argmax. 14 | Therefore, we use max_pool_with_argmax only to extract the pooling mask and 15 | plain max_pool for the max-pooling itself. 16 | """ 17 | with tf.name_scope('MaxPoolArgMax'): 18 | _, mask = tf.nn.max_pool_with_argmax( 19 | net, 20 | ksize=[1, stride, stride, 1], 21 | strides=[1, stride, stride, 1], 22 | padding='SAME') 23 | mask = tf.stop_gradient(mask) 24 | net = slim.max_pool2d(net, kernel_size=[stride, stride], stride=stride) 25 | return net, mask 26 | 27 | 28 | def fake_arg_max_of_max_pool(shape, stride=2): 29 | assert shape[1] % stride == 0 and shape[2] % stride == 0, \ 30 | 'Smart padding is not supported. Indexes %s are not multiple of stride:%d' % (str(shape[1:3]), stride) 31 | mask = np.arange(np.prod(shape[1:])) 32 | mask = mask.reshape(shape[1:]) 33 | mask = mask[::stride, ::stride, :] 34 | mask = np.tile(mask, (shape[0], 1, 1, 1)) 35 | return mask 36 | 37 | 38 | # Thank you, @https://github.com/Pepslee 39 | def unpool(net, mask, stride=2): 40 | assert mask is not None 41 | with tf.name_scope('UnPool2D'): 42 | ksize = [1, stride, stride, 1] 43 | input_shape = net.get_shape().as_list() 44 | # calculate the new output shape 45 | output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) 46 | # calculate indices for batch, height, width and feature maps 47 | one_like_mask = tf.ones_like(mask) 48 | batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int64), shape=[input_shape[0], 1, 1, 1]) 49 | b = one_like_mask * batch_range 50 | y = mask // (output_shape[2] * output_shape[3]) 51 | x = mask % (output_shape[2] * output_shape[3]) // output_shape[3] 52 | feature_range = tf.range(output_shape[3], dtype=tf.int64) 53 | f = one_like_mask * feature_range 54 | # transpose indices & reshape update values to one dimension 55 | updates_size = tf.size(net) 56 | indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])) 57 | values = tf.reshape(net, [updates_size]) 58 | ret = tf.scatter_nd(indices, values, output_shape) 59 | return ret 60 | 61 | 62 | def upsample(net, stride, mode='ZEROS'): 63 | """ 64 | Imitate reverse operation of Max-Pooling by either placing original max values 65 | into a fixed position of the upsampled cell: 66 | [0.9] =>[[.9, 0], (stride=2) 67 | [ 0, 0]] 68 | or copying the value into each cell: 69 | [0.9] =>[[.9, .9], (stride=2) 70 | [ .9, .9]] 71 | 72 | :param net: 4D input tensor with [batch_size, width, height, channels] axis 73 | :param stride: upsampling factor along each spatial axis 74 | :param mode: string 'ZEROS' or 'COPY' indicating which value to use for undefined cells 75 | :return: 4D tensor of size [batch_size, width*stride, height*stride, channels] 76 | """ 77 | assert mode in ['COPY', 'ZEROS'] 78 | with tf.name_scope('Upsampling'): 79 | net = _upsample_along_axis(net, 2, stride, mode=mode) 80 | net = _upsample_along_axis(net, 1, stride, mode=mode) 81 | return net 82 | 83 | 84 | def _upsample_along_axis(volume, axis, stride, mode='ZEROS'): 85 | shape = volume.get_shape().as_list() 86 | 87 | assert mode in ['COPY', 'ZEROS'] 88 | assert 0 <= axis < len(shape) 89 | 90 | target_shape = shape[:] 91 | target_shape[axis] *= stride 92 | 93 | padding = tf.zeros(shape, dtype=volume.dtype) if mode == 'ZEROS' else volume 94 | parts = 
[volume] + [padding for _ in range(stride - 1)] 95 | assert list(inspect.signature(tf.concat).parameters.items())[1][0] == 'axis', 'Wrong TF version' 96 | volume = tf.concat(parts, min(axis+1, len(shape)-1)) 97 | volume = tf.reshape(volume, target_shape) 98 | return volume 99 | 100 | 101 | # VARIABLES 102 | 103 | 104 | def print_model_info(trainable=False): 105 | if not trainable: 106 | for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): 107 | print(v.name, v.get_shape()) 108 | else: 109 | for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): 110 | value = v.eval() 111 | print('TRAINABLE_VARIABLES', v.name, v.get_shape(), 'm:%.4f v:%.4f' % (value.mean(), value.std())) 112 | 113 | 114 | def list_checkpoint_vars(folder): 115 | # print(folder) 116 | f = ch_utils.list_variables(folder) 117 | # print(f) 118 | print('\n'.join(map(str, f))) 119 | 120 | 121 | def get_variable(checkpoint, name): 122 | var = ch_utils.load_variable(tf.train.latest_checkpoint(checkpoint), name) 123 | return var 124 | 125 | 126 | def scope_wrapper(scope_name): 127 | def scope_decorator(func): 128 | def func_wrapper(*args, **kwargs): 129 | with tf.name_scope(scope_name): 130 | return func(*args, **kwargs) 131 | return func_wrapper 132 | return scope_decorator 133 | 134 | 135 | # Gaussian blur 136 | 137 | 138 | def _build_gaussian_kernel(k_size, nsig, channels): 139 | interval = (2 * nsig + 1.) / k_size 140 | x = np.linspace(-nsig - interval / 2., nsig + interval / 2., k_size + 1) 141 | kern1d = np.diff(st.norm.cdf(x)) 142 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) 143 | kernel = kernel_raw / kernel_raw.sum() 144 | out_filter = np.array(kernel, dtype=np.float32) 145 | out_filter = out_filter.reshape((k_size, k_size, 1, 1)) 146 | out_filter = np.repeat(out_filter, channels, axis=2) 147 | return out_filter 148 | 149 | 150 | def blur_gaussian(input, sigma, filter_size): 151 | num_channels = input.get_shape().as_list()[3] 152 | with tf.variable_scope('gaussian_filter'): 153 | kernel = _build_gaussian_kernel(filter_size, sigma, num_channels) 154 | kernel = tf.constant(kernel.flatten(), shape=kernel.shape, name='gauss_weight') 155 | output = tf.nn.depthwise_conv2d(input, kernel, [1, 1, 1, 1], padding='SAME') 156 | return output, kernel 157 | 158 | 159 | def nan_to_zero(tensor): 160 | return tf.where(tf.is_nan(tensor), tf.zeros_like(tensor), tensor) -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yselivonchyk/TensorFlow_DCIGN/ff8d85f3a7b7ca1e5c3f50ff003a1c09a70067cd/tools/__init__.py -------------------------------------------------------------------------------- /tools/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tools to work with checkpoints.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import six 23 | import tensorflow as tf 24 | from tensorflow.python.ops import gen_io_ops 25 | from tensorflow.python.ops import state_ops 26 | from tensorflow.python.ops import variable_scope as vs 27 | from tensorflow.python.ops import variables 28 | from tensorflow.python.platform import gfile 29 | from tensorflow.python.platform import tf_logging as logging 30 | from tensorflow.python.training import saver 31 | from tensorflow.python.training import training as train 32 | 33 | __all__ = [ 34 | "load_checkpoint", 35 | "load_variable", 36 | "list_variables", 37 | "init_from_checkpoint"] 38 | 39 | 40 | def _get_checkpoint_filename(filepattern): 41 | """Returns checkpoint filename given directory or specific filepattern.""" 42 | import os 43 | if os.path.isdir(filepattern): 44 | print('nope', filepattern) 45 | res = tf.train.latest_checkpoint(filepattern) 46 | print('but here', res) 47 | return res 48 | return filepattern 49 | 50 | 51 | def load_checkpoint(filepattern): 52 | """Returns CheckpointReader for latest checkpoint. 53 | 54 | Args: 55 | filepattern: Directory with checkpoints file or path to checkpoint. 56 | 57 | Returns: 58 | `CheckpointReader` object. 59 | 60 | Raises: 61 | ValueError: if checkpoint_dir doesn't have 'checkpoint' file or checkpoints. 62 | """ 63 | filename = _get_checkpoint_filename(filepattern) 64 | if filename is None: 65 | raise ValueError("Couldn't find 'checkpoint' file or checkpoints in " 66 | "given directory %s" % filepattern) 67 | return train.NewCheckpointReader(filename) 68 | 69 | 70 | def load_variable(checkpoint_dir, name): 71 | """Returns a Tensor with the contents of the given variable in the checkpoint. 72 | 73 | Args: 74 | checkpoint_dir: Directory with checkpoints file or path to checkpoint. 75 | name: Name of the tensor to return. 76 | 77 | Returns: 78 | `Tensor` object. 79 | """ 80 | # TODO(b/29227106): Fix this in the right place and remove this. 81 | if name.endswith(":0"): 82 | name = name[:-2] 83 | reader = load_checkpoint(checkpoint_dir) 84 | return reader.get_tensor(name) 85 | 86 | 87 | def list_variables(checkpoint_dir): 88 | """Returns list of all variables in the latest checkpoint. 89 | 90 | Args: 91 | checkpoint_dir: Directory with checkpoints file or path to checkpoint. 92 | 93 | Returns: 94 | List of tuples `(name, shape)`. 95 | """ 96 | reader = load_checkpoint(checkpoint_dir) 97 | variable_map = reader.get_variable_to_shape_map() 98 | names = sorted(variable_map.keys()) 99 | result = [] 100 | for name in names: 101 | result.append((name, variable_map[name])) 102 | return result 103 | 104 | 105 | # pylint: disable=protected-access 106 | # Currently variable_scope doesn't provide very good APIs to access 107 | # all variables under scope and retrieve and check existing scopes. 108 | # TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs. 109 | 110 | 111 | def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec, 112 | name="checkpoint_initializer"): 113 | """Sets variable initializer to assign op form value in checkpoint's tensor. 114 | 115 | Args: 116 | variable: `Variable` object. 117 | file_pattern: string, where to load checkpoints from. 118 | tensor_name: Name of the `Tensor` to load from checkpoint reader. 
119 | slice_spec: Slice specification for loading partitioned variables. 120 | name: Name of the operation. 121 | """ 122 | base_type = variable.dtype.base_dtype 123 | restore_op = gen_io_ops._restore_slice( 124 | file_pattern, 125 | tensor_name, 126 | slice_spec, 127 | base_type, 128 | preferred_shard=-1, 129 | name=name) 130 | variable._initializer_op = state_ops.assign(variable, restore_op) 131 | 132 | 133 | def _set_variable_or_list_initializer(variable_or_list, file_pattern, 134 | tensor_name): 135 | if isinstance(variable_or_list, (list, tuple)): 136 | # A set of slices. 137 | slice_name = None 138 | for v in variable_or_list: 139 | if slice_name is None: 140 | slice_name = v._save_slice_info.full_name 141 | elif slice_name != v._save_slice_info.full_name: 142 | raise ValueError("Slices must all be from the same tensor: %s != %s" % 143 | (slice_name, v._save_slice_info.full_name)) 144 | _set_checkpoint_initializer(v, file_pattern, tensor_name, 145 | v._save_slice_info.spec) 146 | else: 147 | _set_checkpoint_initializer(variable_or_list, file_pattern, tensor_name, "") 148 | 149 | 150 | def init_from_checkpoint(checkpoint_dir, assignment_map): 151 | """Using the assignment map, initializes current variables with loaded tensors. 152 | 153 | Note: This overrides default initialization ops of specified variables and 154 | redefines dtype. 155 | 156 | Assignment map supports following syntax: 157 | `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in 158 | current `scope_name` from `checkpoint_scope_name` with matching variable 159 | names. 160 | `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` - 161 | will initialize `scope_name/variable_name` variable 162 | from `checkpoint_scope_name/some_other_variable`. 163 | `'scope_variable_name': variable` - will initialize given `tf.Variable` 164 | object with variable from the checkpoint. 165 | `'scope_variable_name': list(variable)` - will initialize list of 166 | partitioned variables with variable from the checkpoint. 167 | `'scope_name/': '/'` - will load all variables in current `scope_name` from 168 | checkpoint's root (e.g. no scope). 169 | 170 | Supports loading into partitioned variables, which are represented as 171 | '/part_'. 172 | 173 | Example: 174 | ```python 175 | # Create variables. 176 | with tf.variable_scope('test'): 177 | m = tf.get_variable('my_var') 178 | with tf.variable_scope('test2'): 179 | var2 = tf.get_variable('my_var') 180 | var3 = tf.get_variable(name="my1", shape=[100, 100], 181 | partitioner=lambda shape, dtype: [5, 1]) 182 | ... 183 | # Specify which variables to initialize from checkpoint. 184 | init_from_checkpoint(checkpoint_dir, { 185 | 'some_var': 'test/my_var', 186 | 'some_scope/': 'test2/'}) 187 | ... 188 | # Or use `Variable` objects to identify what to initialize. 189 | init_from_checkpoint(checkpoint_dir, { 190 | 'some_scope/var2': var2, 191 | }) 192 | # Initialize partitioned variables 193 | init_from_checkpoint(checkpoint_dir, { 194 | 'some_var_from_ckpt': 'part_var', 195 | }) 196 | # Or specifying the list of `Variable` objects. 197 | init_from_checkpoint(checkpoint_dir, { 198 | 'some_var_from_ckpt': var3._get_variable_list(), 199 | }) 200 | ... 201 | # Initialize variables as usual. 202 | session.run(tf.get_all_variables()) 203 | ``` 204 | 205 | Args: 206 | checkpoint_dir: Directory with checkpoints file or path to checkpoint. 
207 | assignment_map: Dict, where keys are names of the variables in the 208 | checkpoint and values are current variables or names of current variables 209 | (in default graph). 210 | 211 | Raises: 212 | tf.errors.OpError: If missing checkpoints or tensors in checkpoints. 213 | ValueError: If missing variables in current graph. 214 | """ 215 | filepattern = _get_checkpoint_filename(checkpoint_dir) 216 | reader = load_checkpoint(checkpoint_dir) 217 | variable_map = reader.get_variable_to_shape_map() 218 | for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map): 219 | var = None 220 | # Check if this is Variable object or list of Variable objects (in case of 221 | # partitioned variables). 222 | is_var = lambda x: isinstance(x, variables.Variable) 223 | if is_var(current_var_or_name) or ( 224 | isinstance(current_var_or_name, list) 225 | and all(is_var(v) for v in current_var_or_name)): 226 | var = current_var_or_name 227 | else: 228 | var_scope = vs._get_default_variable_store() 229 | # Check if this variable is in var_store. 230 | var = var_scope._vars.get(current_var_or_name, None) 231 | # Also check if variable is partitioned as list. 232 | if var is None: 233 | if current_var_or_name + "/part_0" in var_scope._vars: 234 | var = [] 235 | i = 0 236 | while current_var_or_name + "/part_%d" % i in var_scope._vars: 237 | var.append(var_scope._vars[current_var_or_name + "/part_%d" % i]) 238 | i += 1 239 | if var is not None: 240 | # If 1 to 1 mapping was provided, find variable in the checkpoint. 241 | if tensor_name_in_ckpt not in variable_map: 242 | raise ValueError("Tensor %s is not found in %s checkpoint" % ( 243 | tensor_name_in_ckpt, checkpoint_dir 244 | )) 245 | if is_var(var): 246 | # Additional at-call-time checks. 247 | if not var.get_shape().is_compatible_with( 248 | variable_map[tensor_name_in_ckpt]): 249 | raise ValueError( 250 | "Shape of variable %s (%s) doesn't match with shape of " 251 | "tensor %s (%s) from checkpoint reader." % ( 252 | var.name, str(var.get_shape()), 253 | tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt]) 254 | )) 255 | var_name = var.name 256 | else: 257 | var_name = ",".join([v.name for v in var]) 258 | _set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt) 259 | logging.info("Initialize variable %s from checkpoint %s with %s" % ( 260 | var_name, checkpoint_dir, tensor_name_in_ckpt 261 | )) 262 | else: 263 | scopes = "" 264 | # TODO(vihanjain): Support list of 'current_var_or_name' here. 265 | if "/" in current_var_or_name: 266 | scopes = current_var_or_name[:current_var_or_name.rindex("/")] 267 | if not tensor_name_in_ckpt.endswith("/"): 268 | raise ValueError( 269 | "Assignment map with scope only name {} should map to scope only " 270 | "{}. Should be 'scope/': 'other_scope/'.".format( 271 | scopes, tensor_name_in_ckpt)) 272 | # If scope to scope mapping was provided, find all variables in the scope. 273 | for var_name in var_scope._vars: 274 | if var_name.startswith(scopes): 275 | # Lookup name with specified prefix and suffix from current variable. 276 | # If tensor_name given is '/' (root), don't use it for full name. 
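# For illustration (scope names here are hypothetical): given an assignment
# map entry {'checkpoint_scope_name/': 'scope_name/'} and a current variable
# 'scope_name/w', the lookup below builds the checkpoint name
# 'checkpoint_scope_name/w'; with a root mapping {'/': 'scope_name/'} it
# looks up plain 'w' instead.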
277 | if tensor_name_in_ckpt != "/": 278 | full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:] 279 | else: 280 | full_tensor_name = var_name[len(scopes) + 1:] 281 | if full_tensor_name not in variable_map: 282 | raise ValueError( 283 | "Tensor %s (%s in %s) is not found in %s checkpoint" % ( 284 | full_tensor_name, var_name[len(scopes) + 1:], 285 | tensor_name_in_ckpt, checkpoint_dir 286 | )) 287 | var = var_scope._vars[var_name] 288 | _set_variable_or_list_initializer(var, filepattern, full_tensor_name) 289 | logging.info("Initialize variable %s from checkpoint %s with %s" % ( 290 | var_name, checkpoint_dir, full_tensor_name 291 | )) 292 | # pylint: enable=protected-access -------------------------------------------------------------------------------- /tools/freeze_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Converts checkpoint variables into Const ops in a standalone GraphDef file. 16 | 17 | This script is designed to take a GraphDef proto, a SaverDef proto, and a set of 18 | variable values stored in a checkpoint file, and output a GraphDef with all of 19 | the variable ops converted into const ops containing the values of the 20 | variables. 21 | 22 | It's useful to do this when we need to load a single file in C++, especially in 23 | environments like mobile or embedded where we may not have access to the 24 | RestoreTensor ops and file loading calls that they rely on. 25 | 26 | An example of command-line usage is: 27 | bazel build tensorflow/python/tools:freeze_graph && \ 28 | bazel-bin/tensorflow/python/tools/freeze_graph \ 29 | --input_graph=some_graph_def.pb \ 30 | --input_checkpoint=model.ckpt-8361242 \ 31 | --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax 32 | 33 | You can also look at freeze_graph_test.py for an example of how to use it. 
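Once written, the frozen GraphDef can be loaded back from Python in a few
lines (a minimal sketch; the file path and the import name are placeholders):

  import tensorflow as tf
  graph_def = tf.GraphDef()
  with tf.gfile.FastGFile('/tmp/frozen_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def, name='')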
34 | 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import tensorflow as tf 41 | 42 | from google.protobuf import text_format 43 | from tensorflow.python.framework import graph_util 44 | 45 | 46 | FLAGS = tf.app.flags.FLAGS 47 | 48 | tf.app.flags.DEFINE_string("input_graph", "", 49 | """TensorFlow 'GraphDef' file to load.""") 50 | tf.app.flags.DEFINE_string("input_saver", "", 51 | """TensorFlow saver file to load.""") 52 | tf.app.flags.DEFINE_string("input_checkpoint", "", 53 | """TensorFlow variables file to load.""") 54 | tf.app.flags.DEFINE_string("output_graph", "", 55 | """Output 'GraphDef' file name.""") 56 | tf.app.flags.DEFINE_boolean("input_binary", False, 57 | """Whether the input files are in binary format.""") 58 | tf.app.flags.DEFINE_string("output_node_names", "", 59 | """The name of the output nodes, comma separated.""") 60 | tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all", 61 | """The name of the master restore operator.""") 62 | tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0", 63 | """The name of the tensor holding the save path.""") 64 | tf.app.flags.DEFINE_boolean("clear_devices", True, 65 | """Whether to remove device specifications.""") 66 | tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of " 67 | "initializer nodes to run before freezing.") 68 | 69 | 70 | def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, 71 | output_node_names, restore_op_name, filename_tensor_name, 72 | output_graph, clear_devices, initializer_nodes): 73 | """Converts all variables in a graph and checkpoint into constants.""" 74 | 75 | if not tf.gfile.Exists(input_graph): 76 | print("Input graph file '" + input_graph + "' does not exist!") 77 | return -1 78 | 79 | if input_saver and not tf.gfile.Exists(input_saver): 80 | print("Input saver file '" + input_saver + "' does not exist!") 81 | return -1 82 | 83 | if not tf.gfile.Glob(input_checkpoint): 84 | print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") 85 | return -1 86 | 87 | if not output_node_names: 88 | print("You need to supply the name of a node to --output_node_names.") 89 | return -1 90 | 91 | input_graph_def = tf.GraphDef() 92 | mode = "rb" if input_binary else "r" 93 | with tf.gfile.FastGFile(input_graph, mode) as f: 94 | if input_binary: 95 | input_graph_def.ParseFromString(f.read()) 96 | else: 97 | text_format.Merge(f.read(), input_graph_def) 98 | # Remove all the explicit device specifications for this node. This helps to 99 | # make the graph more portable. 
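# (For example, a node pinned to a device string like "/job:worker/gpu:0"
# during training would otherwise fail to load on a machine that doesn't
# have that device.)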
100 | if clear_devices: 101 | for node in input_graph_def.node: 102 | node.device = "" 103 | _ = tf.import_graph_def(input_graph_def, name="") 104 | 105 | with tf.Session() as sess: 106 | if input_saver: 107 | with tf.gfile.FastGFile(input_saver, mode) as f: 108 | saver_def = tf.train.SaverDef() 109 | if input_binary: 110 | saver_def.ParseFromString(f.read()) 111 | else: 112 | text_format.Merge(f.read(), saver_def) 113 | saver = tf.train.Saver(saver_def=saver_def) 114 | saver.restore(sess, input_checkpoint) 115 | else: 116 | sess.run([restore_op_name], {filename_tensor_name: input_checkpoint}) 117 | if initializer_nodes: 118 | sess.run(initializer_nodes) 119 | output_graph_def = graph_util.convert_variables_to_constants( 120 | sess, input_graph_def, output_node_names.split(",")) 121 | 122 | with tf.gfile.GFile(output_graph, "wb") as f: 123 | f.write(output_graph_def.SerializeToString()) 124 | print("%d ops in the final graph." % len(output_graph_def.node)) 125 | 126 | 127 | def main(unused_args): 128 | freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, 129 | FLAGS.input_checkpoint, FLAGS.output_node_names, 130 | FLAGS.restore_op_name, FLAGS.filename_tensor_name, 131 | FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes) 132 | 133 | if __name__ == "__main__": 134 | tf.app.run() 135 | -------------------------------------------------------------------------------- /tools/freeze_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests the graph freezing tool.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.framework import test_util 25 | from tensorflow.python.tools import freeze_graph 26 | 27 | 28 | class FreezeGraphTest(test_util.TensorFlowTestCase): 29 | 30 | def testFreezeGraph(self): 31 | 32 | checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") 33 | checkpoint_state_name = "checkpoint_state" 34 | input_graph_name = "input_graph.pb" 35 | output_graph_name = "output_graph.pb" 36 | 37 | # We'll create an input graph that has a single variable containing 1.0, 38 | # and that then multiplies it by 2. 
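# After freezing, the Variable op should have been replaced by a Const
# holding 1.0, so the graph should still evaluate to 2.0 without any
# checkpoint present.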
39 | with tf.Graph().as_default(): 40 | variable_node = tf.Variable(1.0, name="variable_node") 41 | output_node = tf.mul(variable_node, 2.0, name="output_node") 42 | sess = tf.Session() 43 | init = tf.initialize_all_variables() 44 | sess.run(init) 45 | output = sess.run(output_node) 46 | self.assertNear(2.0, output, 0.00001) 47 | saver = tf.train.Saver() 48 | saver.save(sess, checkpoint_prefix, global_step=0, 49 | latest_filename=checkpoint_state_name) 50 | tf.train.write_graph(sess.graph.as_graph_def(), self.get_temp_dir(), 51 | input_graph_name) 52 | 53 | # We save out the graph to disk, and then call the const conversion 54 | # routine. 55 | input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name) 56 | input_saver_def_path = "" 57 | input_binary = False 58 | input_checkpoint_path = checkpoint_prefix + "-0" 59 | output_node_names = "output_node" 60 | restore_op_name = "save/restore_all" 61 | filename_tensor_name = "save/Const:0" 62 | output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 63 | clear_devices = False 64 | 65 | freeze_graph.freeze_graph(input_graph_path, input_saver_def_path, 66 | input_binary, input_checkpoint_path, 67 | output_node_names, restore_op_name, 68 | filename_tensor_name, output_graph_path, 69 | clear_devices, "") 70 | 71 | # Now we make sure the variable is now a constant, and that the graph still 72 | # produces the expected result. 73 | with tf.Graph().as_default(): 74 | output_graph_def = tf.GraphDef() 75 | with open(output_graph_path, "rb") as f: 76 | output_graph_def.ParseFromString(f.read()) 77 | _ = tf.import_graph_def(output_graph_def, name="") 78 | 79 | self.assertEqual(4, len(output_graph_def.node)) 80 | for node in output_graph_def.node: 81 | self.assertNotEqual("Variable", node.op) 82 | 83 | with tf.Session() as sess: 84 | output_node = sess.graph.get_tensor_by_name("output_node:0") 85 | output = sess.run(output_node) 86 | self.assertNear(2.0, output, 0.00001) 87 | 88 | if __name__ == "__main__": 89 | tf.test.main() 90 | -------------------------------------------------------------------------------- /tools/graph_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Gives estimates of computation and parameter sizes for a GraphDef. 17 | 18 | This script takes a GraphDef representing a network, and produces rough 19 | estimates of the number of floating-point operations needed to implement it and 20 | how many parameters are stored. You need to pass in the input size, and the 21 | results are only approximate, since it only calculates them for a subset of 22 | common operations. 
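As a concrete example of how the statistics are counted, a MatMul of a
[10, 20] input with a [20, 5] weight matrix is reported as
10 * 20 * 5 * 2 = 2000 flops and 20 * 5 = 100 weight parameters (this is
exactly the case exercised in graph_metrics_test.py).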
23 | 24 | If you have downloaded the Inception graph for the label_image example, an 25 | example of using this script would be: 26 | 27 | bazel-bin/third_party/tensorflow/python/tools/graph_metrics \ 28 | --graph tensorflow_inception_graph.pb \ 29 | --statistics=weight_parameters,flops 30 | 31 | """ 32 | from __future__ import absolute_import 33 | from __future__ import division 34 | from __future__ import print_function 35 | 36 | import locale 37 | 38 | import tensorflow as tf 39 | 40 | from google.protobuf import text_format 41 | 42 | from tensorflow.core.framework import graph_pb2 43 | from tensorflow.python.framework import ops 44 | 45 | 46 | FLAGS = tf.flags.FLAGS 47 | 48 | tf.flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""") 49 | tf.flags.DEFINE_bool("input_binary", True, 50 | """Whether the input files are in binary format.""") 51 | tf.flags.DEFINE_string("input_layer", "Mul:0", 52 | """The name of the input node.""") 53 | tf.flags.DEFINE_integer("batch_size", 1, 54 | """The batch size to use for the calculations.""") 55 | tf.flags.DEFINE_string("statistics", "weight_parameters,flops", 56 | """Which statistic types to examine.""") 57 | tf.flags.DEFINE_string("input_shape_override", "", 58 | """If this is set, the comma-separated values will be""" 59 | """ used to set the shape of the input layer.""") 60 | tf.flags.DEFINE_boolean("print_nodes", False, 61 | """Whether to show statistics for each op.""") 62 | 63 | 64 | def print_stat(prefix, statistic_type, value): 65 | if value is None: 66 | friendly_value = "None" 67 | else: 68 | friendly_value = locale.format("%d", value, grouping=True) 69 | print("%s%s=%s" % (prefix, statistic_type, friendly_value)) 70 | 71 | 72 | def main(unused_args): 73 | if not tf.gfile.Exists(FLAGS.graph): 74 | print("Input graph file '" + FLAGS.graph + "' does not exist!") 75 | return -1 76 | graph_def = graph_pb2.GraphDef() 77 | with open(FLAGS.graph, "rb") as f: 78 | if FLAGS.input_binary: 79 | graph_def.ParseFromString(f.read()) 80 | else: 81 | text_format.Merge(f.read(), graph_def) 82 | statistic_types = FLAGS.statistics.split(",") 83 | if FLAGS.input_shape_override: 84 | input_shape_override = map(int, FLAGS.input_shape_override.split(",")) 85 | else: 86 | input_shape_override = None 87 | total_stats, node_stats = calculate_graph_metrics( 88 | graph_def, statistic_types, FLAGS.input_layer, input_shape_override, 89 | FLAGS.batch_size) 90 | if FLAGS.print_nodes: 91 | for node in graph_def.node: 92 | for statistic_type in statistic_types: 93 | current_stats = node_stats[statistic_type][node.name] 94 | print_stat(node.name + "(" + node.op + "): ", statistic_type, 95 | current_stats.value) 96 | for statistic_type in statistic_types: 97 | value = total_stats[statistic_type].value 98 | print_stat("Total: ", statistic_type, value) 99 | 100 | 101 | def calculate_graph_metrics(graph_def, statistic_types, input_layer, 102 | input_shape_override, batch_size): 103 | """Looks at the performance statistics of all nodes in the graph.""" 104 | _ = tf.import_graph_def(graph_def, name="") 105 | total_stats = {} 106 | node_stats = {} 107 | for statistic_type in statistic_types: 108 | total_stats[statistic_type] = ops.OpStats(statistic_type) 109 | node_stats[statistic_type] = {} 110 | # Make sure we get pretty-printed numbers with separators. 
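# (An empty locale string selects the user's default environment settings.)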
111 | locale.setlocale(locale.LC_ALL, "") 112 | with tf.Session() as sess: 113 | input_tensor = sess.graph.get_tensor_by_name(input_layer) 114 | input_shape_tensor = input_tensor.get_shape() 115 | if input_shape_tensor: 116 | input_shape = input_shape_tensor.as_list() 117 | else: 118 | input_shape = None 119 | if input_shape_override: 120 | input_shape = input_shape_override 121 | if input_shape is None: 122 | raise ValueError("""No input shape was provided on the command line,""" 123 | """ and the input op itself had no default shape, so""" 124 | """ shape inference couldn't be performed. This is""" 125 | """ required for metrics calculations.""") 126 | input_shape[0] = batch_size 127 | input_tensor.set_shape(input_shape) 128 | for node in graph_def.node: 129 | # Ensure that the updated input shape has been fully-propagated before we 130 | # ask for the statistics, since they may depend on the output size. 131 | op = sess.graph.get_operation_by_name(node.name) 132 | ops.set_shapes_for_outputs(op) 133 | for statistic_type in statistic_types: 134 | current_stats = ops.get_stats_for_node_def(sess.graph, node, 135 | statistic_type) 136 | node_stats[statistic_type][node.name] = current_stats 137 | total_stats[statistic_type] += current_stats 138 | return total_stats, node_stats 139 | 140 | if __name__ == "__main__": 141 | tf.app.run() 142 | -------------------------------------------------------------------------------- /tools/graph_metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests the graph metrics tool.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from tensorflow.python.tools import graph_metrics 23 | 24 | 25 | class GraphMetricsTest(tf.test.TestCase): 26 | 27 | def testGraphMetrics(self): 28 | with tf.Graph().as_default(): 29 | input_node = tf.placeholder(tf.float32, shape=[10, 20], name="input_node") 30 | weights_node = tf.constant(0.0, 31 | dtype=tf.float32, 32 | shape=[20, 5], 33 | name="weights_node") 34 | tf.matmul(input_node, weights_node, name="matmul_node") 35 | sess = tf.Session() 36 | graph_def = sess.graph.as_graph_def() 37 | statistic_types = ["weight_parameters", "flops"] 38 | total_stats, node_stats = graph_metrics.calculate_graph_metrics( 39 | graph_def, statistic_types, "input_node:0", None, 10) 40 | expected = {"weight_parameters": 100, "flops": 2000} 41 | for statistic_type in statistic_types: 42 | current_stats = node_stats[statistic_type]["matmul_node"] 43 | self.assertEqual(expected[statistic_type], current_stats.value) 44 | for statistic_type in statistic_types: 45 | current_stats = total_stats[statistic_type] 46 | self.assertEqual(expected[statistic_type], current_stats.value) 47 | 48 | 49 | if __name__ == "__main__": 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /tools/inspect_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A simple script for inspecting checkpoint files.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import sys 22 | 23 | import tensorflow as tf 24 | 25 | FLAGS = tf.app.flags.FLAGS 26 | tf.app.flags.DEFINE_string("file_name", './../tmp/doom_bs__act|sigmoid__bs|20__h|500|5|500__init|na__inp|cbd4__lr|0.0004__opt|AO/-3560', "Checkpoint filename") 27 | tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect") 28 | 29 | 30 | def print_tensors_in_checkpoint_file(file_name, tensor_name): 31 | """Prints tensors in a checkpoint file. 32 | 33 | If no `tensor_name` is provided, prints the tensor names and shapes 34 | in the checkpoint file. 35 | 36 | If `tensor_name` is provided, prints the content of the tensor. 37 | 38 | Args: 39 | file_name: Name of the checkpoint file. 40 | tensor_name: Name of the tensor in the checkpoint file to print.
41 | """ 42 | try: 43 | reader = tf.train.NewCheckpointReader(file_name) 44 | if not tensor_name: 45 | print(reader.debug_string().decode("utf-8")) 46 | else: 47 | print("tensor_name: ", tensor_name) 48 | print(reader.get_tensor(tensor_name)) 49 | except Exception as e: # pylint: disable=broad-except 50 | print(str(e)) 51 | if "corrupted compressed block contents" in str(e): 52 | print("It's likely that your checkpoint file has been compressed " 53 | "with SNAPPY.") 54 | 55 | 56 | def main(unused_argv): 57 | if not FLAGS.file_name: 58 | print("No checkpoint file was specified.") 59 | print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " 60 | "[--tensor_name=tensor_to_print]") 61 | sys.exit(1) 62 | else: 63 | print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name) 64 | 65 | if __name__ == "__main__": 66 | 67 | tf.app.run() 68 | -------------------------------------------------------------------------------- /tools/strip_unused.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Removes unneeded nodes from a GraphDef file. 16 | 17 | This script is designed to help streamline models, by taking the input and 18 | output nodes that will be used by an application and figuring out the smallest 19 | set of operations that are required to run for those arguments. The resulting 20 | minimal graph is then saved out. 21 | 22 | The advantages of running this script are: 23 | - You may be able to shrink the file size. 24 | - Operations that are unsupported on your platform but still present can be 25 | safely removed. 26 | The resulting graph may not be as flexible as the original though, since any 27 | input nodes that weren't explicitly mentioned may not be accessible any more. 28 | 29 | An example of command-line usage is: 30 | bazel build tensorflow/python/tools:strip_unused && \ 31 | bazel-bin/tensorflow/python/tools/strip_unused \ 32 | --input_graph=some_graph_def.pb \ 33 | --output_graph=/tmp/stripped_graph.pb \ 34 | --input_node_names=input0 \ 35 | --output_node_names=softmax 36 | 37 | You can also look at strip_unused_test.py for an example of how to use it.
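After stripping, each node named in --input_node_names has been replaced by a
Placeholder and must be fed when the graph is run. A minimal sketch of running
the stripped graph (node names follow the example above; `batch` stands in
for your input data):

  input_node = sess.graph.get_tensor_by_name('input0:0')
  output_node = sess.graph.get_tensor_by_name('softmax:0')
  result = sess.run(output_node, feed_dict={input_node: batch})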
38 | 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | import copy 44 | 45 | import tensorflow as tf 46 | 47 | from google.protobuf import text_format 48 | from tensorflow.python.framework import graph_util 49 | 50 | 51 | FLAGS = tf.app.flags.FLAGS 52 | 53 | tf.app.flags.DEFINE_string("input_graph", "", 54 | """TensorFlow 'GraphDef' file to load.""") 55 | tf.app.flags.DEFINE_boolean("input_binary", False, 56 | """Whether the input files are in binary format.""") 57 | tf.app.flags.DEFINE_string("output_graph", "", 58 | """Output 'GraphDef' file name.""") 59 | tf.app.flags.DEFINE_string("input_node_names", "", 60 | """The name of the input nodes, comma separated.""") 61 | tf.app.flags.DEFINE_string("output_node_names", "", 62 | """The name of the output nodes, comma separated.""") 63 | tf.app.flags.DEFINE_integer("placeholder_type_enum", 64 | tf.float32.as_datatype_enum, 65 | """The AttrValue enum to use for placeholders.""") 66 | 67 | 68 | def strip_unused(input_graph, input_binary, output_graph, input_node_names, 69 | output_node_names, placeholder_type_enum): 70 | """Removes unused nodes from a graph.""" 71 | 72 | if not tf.gfile.Exists(input_graph): 73 | print("Input graph file '" + input_graph + "' does not exist!") 74 | return -1 75 | 76 | if not output_node_names: 77 | print("You need to supply the name of a node to --output_node_names.") 78 | return -1 79 | 80 | input_graph_def = tf.GraphDef() 81 | mode = "rb" if input_binary else "r" 82 | with tf.gfile.FastGFile(input_graph, mode) as f: 83 | if input_binary: 84 | input_graph_def.ParseFromString(f.read()) 85 | else: 86 | text_format.Merge(f.read(), input_graph_def) 87 | 88 | # Here we replace the nodes we're going to override as inputs with 89 | # placeholders so that any unused nodes that are inputs to them are 90 | # automatically stripped out by extract_sub_graph(). 91 | input_node_names_list = input_node_names.split(",") 92 | inputs_replaced_graph_def = tf.GraphDef() 93 | for node in input_graph_def.node: 94 | if node.name in input_node_names_list: 95 | placeholder_node = tf.NodeDef() 96 | placeholder_node.op = "Placeholder" 97 | placeholder_node.name = node.name 98 | placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue( 99 | type=placeholder_type_enum)) 100 | inputs_replaced_graph_def.node.extend([placeholder_node]) 101 | else: 102 | inputs_replaced_graph_def.node.extend([copy.deepcopy(node)]) 103 | 104 | output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def, 105 | output_node_names.split(",")) 106 | 107 | with tf.gfile.GFile(output_graph, "wb") as f: 108 | f.write(output_graph_def.SerializeToString()) 109 | print("%d ops in the final graph." % len(output_graph_def.node)) 110 | 111 | 112 | def main(unused_args): 113 | strip_unused(FLAGS.input_graph, FLAGS.input_binary, FLAGS.output_graph, 114 | FLAGS.input_node_names, FLAGS.output_node_names, 115 | FLAGS.placeholder_type_enum) 116 | 117 | if __name__ == "__main__": 118 | tf.app.run() 119 | -------------------------------------------------------------------------------- /tools/strip_unused_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests the unused-node stripping tool.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.framework import test_util 25 | from tensorflow.python.tools import strip_unused 26 | 27 | 28 | class StripUnusedTest(test_util.TensorFlowTestCase): 29 | 30 | def testStripUnused(self): 31 | input_graph_name = "input_graph.pb" 32 | output_graph_name = "output_graph.pb" 33 | 34 | # We'll create an input graph that has a single constant containing 1.0, 35 | # subtracts 3.0 from it, and then multiplies the result by 2. 36 | with tf.Graph().as_default(): 37 | constant_node = tf.constant(1.0, name="constant_node") 38 | wanted_input_node = tf.sub(constant_node, 3.0, name="wanted_input_node") 39 | output_node = tf.mul(wanted_input_node, 2.0, name="output_node") 40 | tf.add(output_node, 2.0, name="later_node") 41 | sess = tf.Session() 42 | output = sess.run(output_node) 43 | self.assertNear(-4.0, output, 0.00001) 44 | tf.train.write_graph(sess.graph.as_graph_def(), self.get_temp_dir(), 45 | input_graph_name) 46 | 47 | # We save out the graph to disk, and then call the node stripping 48 | # routine. 49 | input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name) 50 | input_binary = False 51 | input_node_names = "wanted_input_node" 52 | output_node_names = "output_node" 53 | output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 54 | 55 | strip_unused.strip_unused(input_graph_path, input_binary, output_graph_path, 56 | input_node_names, output_node_names, 57 | tf.float32.as_datatype_enum) 58 | 59 | # Now we make sure the unwanted nodes are gone, and that the graph still 60 | # produces the expected result.
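# Note: 'wanted_input_node' is a Placeholder in the stripped graph, so the
# verification below has to feed it a value explicitly.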
61 | with tf.Graph().as_default(): 62 | output_graph_def = tf.GraphDef() 63 | with open(output_graph_path, "rb") as f: 64 | output_graph_def.ParseFromString(f.read()) 65 | _ = tf.import_graph_def(output_graph_def, name="") 66 | 67 | self.assertEqual(3, len(output_graph_def.node)) 68 | for node in output_graph_def.node: 69 | self.assertNotEqual("Add", node.op) 70 | self.assertNotEqual("Sub", node.op) 71 | 72 | with tf.Session() as sess: 73 | input_node = sess.graph.get_tensor_by_name("wanted_input_node:0") 74 | output_node = sess.graph.get_tensor_by_name("output_node:0") 75 | output = sess.run(output_node, feed_dict={input_node: [10.0]}) 76 | self.assertNear(20.0, output, 0.00001) 77 | 78 | if __name__ == "__main__": 79 | tf.test.main() 80 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | from matplotlib import pyplot as plt 4 | import os 5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 6 | import collections 7 | import tensorflow as tf 8 | from scipy import misc 9 | import re 10 | import sys 11 | import subprocess as sp 12 | import warnings 13 | import functools 14 | import scipy 15 | import random 16 | 17 | 18 | 19 | IMAGE_FOLDER = './img/' 20 | TEMP_FOLDER = './tmp/' 21 | EPOCH_THRESHOLD = 4 22 | FLAGS = tf.app.flags.FLAGS 23 | 24 | _start_time = None 25 | 26 | 27 | # CONSOLE OPERATIONS 28 | 29 | 30 | def reset_start_time(): 31 | global _start_time 32 | _start_time = None 33 | 34 | 35 | def _get_time_offset(): 36 | global _start_time 37 | time = datetime.datetime.now() 38 | if _start_time is None: 39 | _start_time = time 40 | return '\t\t' 41 | sec = (time - _start_time).total_seconds() 42 | res = '(+%d)\t' % sec if sec < 60 else '(+%d:%02d)\t' % (sec/60, sec%60) 43 | return res 44 | 45 | 46 | def print_time(*args, same_line=False): 47 | string = '' 48 | for a in args: 49 | string += str(a) + ' ' 50 | time = datetime.datetime.now().time().strftime('%H:%M:%S') 51 | offset = _get_time_offset() 52 | res = '%s%s %s' % (str(time), offset, str(string)) 53 | print_color(res, same_line=same_line) 54 | 55 | 56 | def print_info(string, color=32, same_line=False): 57 | print_color('\t' + str(string), color=color, same_line=same_line) 58 | 59 | 60 | same_line_prev = None 61 | 62 | 63 | def print_color(string, color=33, same_line=False): 64 | global same_line_prev 65 | res = '%c[1;%dm%s%c[0m' % (27, color, str(string), 27) 66 | if same_line: 67 | print('\r ' + 68 | ' ', end=' ') 69 | print('\r' + res, end=' ') 70 | else: 71 | # if same_line_prev: 72 | # print('\n') 73 | print(res) 74 | same_line_prev = same_line 75 | 76 | 77 | def mnist_select_n_classes(train_images, train_labels, num_classes, min=None, scale=1.0): 78 | result_images, result_labels = [], [] 79 | for i, j in zip(train_images, train_labels): 80 | if np.sum(j[0:num_classes]) > 0: 81 | result_images.append(i) 82 | result_labels.append(j[0:num_classes]) 83 | inputs = np.asarray(result_images) 84 | 85 | inputs *= scale 86 | if min is not None: 87 | inputs = inputs - np.min(inputs) + min 88 | return inputs, np.asarray(result_labels) 89 | 90 | 91 | # IMAGE OPERATIONS 92 | 93 | 94 | def _save_image(name='image', save_params=None, image=None): 95 | if save_params is not None and 'e' in save_params and save_params['e'] < EPOCH_THRESHOLD: 96 | print_info('IMAGE: output is not saved. 
epochs %d < %d' % (save_params['e'], EPOCH_THRESHOLD), color=31) 97 | return 98 | 99 | file_name = name if save_params is None else to_file_name(save_params) 100 | file_name += '.png' 101 | name = os.path.join(FLAGS.save_path, file_name) 102 | 103 | if image is not None: 104 | misc.imsave(name, arr=image, format='png') 105 | 106 | 107 | def _show_picture(pic): 108 | fig = plt.figure() 109 | size = fig.get_size_inches() 110 | fig.set_size_inches(size[0], size[1] * 2, forward=True) 111 | plt.imshow(pic, cmap='Greys_r') 112 | 113 | 114 | def concat_images(im1, im2, axis=0): 115 | if im1 is None: 116 | return im2 117 | return np.concatenate((im1, im2), axis=axis) 118 | 119 | 120 | def _reconstruct_picture_line(pictures, shape): 121 | line_picture = None 122 | for _, img in enumerate(pictures): 123 | if len(img.shape) == 1: 124 | img = (np.reshape(img, shape)) 125 | if len(img.shape) == 3 and img.shape[2] == 1: 126 | img = (np.reshape(img, (img.shape[0], img.shape[1]))) 127 | line_picture = concat_images(line_picture, img) 128 | return line_picture 129 | 130 | 131 | def show_plt(): 132 | plt.show() 133 | 134 | 135 | def _construct_img_shape(img): 136 | assert int(np.sqrt(img.shape[0])) == np.sqrt(img.shape[0]) 137 | return int(np.sqrt(img.shape[0])), int(np.sqrt(img.shape[0])), 1 138 | 139 | 140 | def images_to_uint8(func): 141 | def normalize(arr): 142 | # if type(arr) == np.ndarray and arr.dtype != np.uint8 and len(arr.shape) >= 3: 143 | if type(arr) == np.ndarray and len(arr.shape) >= 3: 144 | if np.min(arr) < 0: 145 | print('image array normalization: negative values') 146 | if np.max(arr) < 10: 147 | arr *= 255 148 | if arr.shape[-1] == 4 or arr.shape[-1] == 2: 149 | # drop the last (alpha/extra) channel before casting to uint8 150 | arr = arr[..., :arr.shape[-1]-1] 151 | return arr.astype(np.uint8) 152 | return arr 153 | 154 | def func_wrapper(*args, **kwargs): 155 | new_args = [normalize(el) for el in args] 156 | new_kwargs = {k: normalize(kwargs[k]) for _, k in enumerate(kwargs)} 157 | return func(*tuple(new_args), **new_kwargs) 158 | return func_wrapper 159 | 160 | 161 | def fig2buf(fig): 162 | fig.canvas.draw() 163 | return fig.canvas.tostring_rgb() 164 | 165 | 166 | def fig2rgb_array(fig, expand=True): 167 | fig.canvas.draw() 168 | buf = fig.canvas.tostring_rgb() 169 | ncols, nrows = fig.canvas.get_width_height() 170 | shape = (nrows, ncols, 3) if not expand else (1, nrows, ncols, 3) 171 | return np.fromstring(buf, dtype=np.uint8).reshape(shape) 172 | 173 | 174 | @images_to_uint8 175 | def reconstruct_images_epochs(epochs, original=None, save_params=None, img_shape=None): 176 | full_picture = None 177 | img_shape = img_shape if img_shape is not None else _construct_img_shape(epochs[0][0]) 178 | 179 | # print(original.dtype, epochs.dtype, np.max(original), np.max(epochs)) 180 | 181 | if original is not None and original.dtype != np.uint8: 182 | original = (original * 255).astype(np.uint8) 183 | if epochs is not None and epochs.dtype != np.uint8: 184 | epochs = (epochs * 255).astype(np.uint8) 185 | 186 | # print('image reconstruction: ', original.dtype, epochs.dtype, np.max(original), np.max(epochs)) 187 | 188 | if original is not None and epochs is not None and len(epochs) >= 3: 189 | min_ref, max_ref = np.min(original), np.max(original) 190 | print_info('epoch avg: (original: %s) -> %s' % ( 191 | str(np.mean(original)), str((np.mean(epochs[0]), np.mean(epochs[1]), np.mean(epochs[2]))))) 192 | print_info('reconstruction char. 
in epochs (min, max)|original: (%f %f)|(%f %f)' % ( 193 | np.min(epochs[1:]), np.max(epochs), min_ref, max_ref)) 194 | 195 | if epochs is not None: 196 | for _, epoch in enumerate(epochs): 197 | full_picture = concat_images(full_picture, _reconstruct_picture_line(epoch, img_shape), axis=1) 198 | if original is not None: 199 | full_picture = concat_images(full_picture, _reconstruct_picture_line(original, img_shape), axis=1) 200 | _show_picture(full_picture) 201 | _save_image(save_params=save_params, image=full_picture) 202 | 203 | 204 | def model_to_file_name(FLAGS, folder=None, ext=None): 205 | postfix = '' if len(FLAGS.postfix) == 0 else '_%s' % FLAGS.postfix 206 | name = '%s.%s__i_%s%s' % (FLAGS.model, FLAGS.net.replace('-', '_'), FLAGS.input_name, postfix) 207 | if ext: 208 | name += '.' + ext 209 | if folder: 210 | name = os.path.join(folder, name) 211 | return name 212 | 213 | 214 | def mkdir(folders): 215 | if isinstance(folders, str): 216 | folders = [folders] 217 | for _, folder in enumerate(folders): 218 | if not os.path.exists(folder): 219 | os.mkdir(folder) 220 | 221 | 222 | def configure_folders(FLAGS): 223 | folder_name = model_to_file_name(FLAGS) + '/' 224 | FLAGS.save_path = os.path.join(TEMP_FOLDER, folder_name) 225 | FLAGS.logdir = FLAGS.save_path 226 | print_color(os.path.abspath(FLAGS.logdir)) 227 | mkdir([TEMP_FOLDER, IMAGE_FOLDER, FLAGS.save_path, FLAGS.logdir]) 228 | 229 | with open(os.path.join(FLAGS.save_path, '!note.txt'), "a") as f: 230 | f.write('\n' + ' '.join(sys.argv) + '\n') 231 | f.write(print_flags(FLAGS, print=False)) 232 | if len(FLAGS.comment) > 0: 233 | f.write('\n\n%s\n' % FLAGS.comment) 234 | 235 | 236 | def get_files(folder="./visualizations/", filter=None): 237 | all = [] 238 | for root, dirs, files in os.walk(folder): 239 | # print(root, dirs, files) 240 | if filter: 241 | files = [x for x in files if re.match(filter, x)] 242 | all += files 243 | return [os.path.join(folder, x) for x in all] 244 | 245 | 246 | def get_latest_file(folder="./visualizations/", filter=None): 247 | latest_file, latest_mod_time = None, None 248 | for root, dirs, files in os.walk(folder): 249 | # print(root, dirs, files) 250 | if filter: 251 | files = [x for x in files if re.match(filter, x)] 252 | # print('\n\r'.join(files)) 253 | for file in files: 254 | file_path = os.path.join(root, file) 255 | modification_time = os.path.getmtime(file_path) 256 | if not latest_mod_time or modification_time > latest_mod_time: 257 | latest_mod_time = modification_time 258 | latest_file = file_path 259 | if latest_file is None: 260 | print_info('Could not find file matching %s' % str(filter)) 261 | return latest_file 262 | 263 | 264 | def concatenate(x, y, take=None): 265 | """ 266 | Stitches two np arrays together until maximum length, when specified 267 | """ 268 | if take is not None and x is not None and len(x) >= take: 269 | return x 270 | if x is None: 271 | res = y 272 | else: 273 | res = np.concatenate((x, y)) 274 | return res[:take] if take is not None else res 275 | 276 | 277 | # MISC 278 | 279 | 280 | def list_object_attributes(obj): 281 | print('Object type: %s\t\tattributes:' % str(type(obj))) 282 | print('\n\t'.join(map(str, obj.__dict__.keys()))) 283 | 284 | 285 | def print_list(list): 286 | print('\n'.join(map(str, list))) 287 | 288 | 289 | def print_float_list(list, format='%.4f'): 290 | return ' '.join(map(lambda x: format%x, list)) 291 | 292 | 293 | def timeit(method): 294 | def timed(*args, **kw): 295 | ts = time.time() 296 | result = method(*args, **kw) 297 | te = 
time.time() 298 | print('%r %2.2f sec' % (method.__name__, te-ts)) 299 | return result 300 | return timed 301 | 302 | 303 | def deprecated(func): 304 | '''This is a decorator which can be used to mark functions 305 | as deprecated. It will result in a warning being emitted 306 | when the function is used.''' 307 | 308 | @functools.wraps(func) 309 | def new_func(*args, **kwargs): 310 | warnings.warn_explicit( 311 | "Call to deprecated function {}.".format(func.__name__), 312 | category=DeprecationWarning, 313 | filename=func.__code__.co_filename, 314 | lineno=func.__code__.co_firstlineno + 1 315 | ) 316 | return func(*args, **kwargs) 317 | 318 | return new_func 319 | 320 | 321 | import numpy as np 322 | ACCEPTABLE_AVAILABLE_MEMORY = 1024 323 | 324 | 325 | def mask_busy_gpus(leave_unmasked=1, random=True): 326 | try: 327 | command = "nvidia-smi --query-gpu=memory.free --format=csv" 328 | memory_free_info = _output_to_list(sp.check_output(command.split()))[1:] 329 | memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] 330 | available_gpus = [i for i, x in enumerate(memory_free_values) if x > ACCEPTABLE_AVAILABLE_MEMORY] 331 | 332 | if len(available_gpus) < leave_unmasked: 333 | print('Found only %d usable GPUs in the system' % len(available_gpus)) 334 | exit(0) 335 | 336 | if random: 337 | available_gpus = np.asarray(available_gpus) 338 | np.random.shuffle(available_gpus) 339 | 340 | # update CUDA variable 341 | gpus = available_gpus[:leave_unmasked] 342 | setting = ','.join(map(str, gpus)) 343 | os.environ["CUDA_VISIBLE_DEVICES"] = setting 344 | print('Left %d GPU(s) unmasked: [%s] (from %s available)' 345 | % (leave_unmasked, setting, str(available_gpus))) 346 | except FileNotFoundError as e: 347 | print('"nvidia-smi" is probably not installed. 
GPUs are not masked') 348 | print(e) 349 | except sp.CalledProcessError as e: 350 | print("Error on GPU masking:\n", e.output) 351 | 352 | 353 | def _output_to_list(output): 354 | return output.decode('ascii').split('\n')[:-1] 355 | 356 | 357 | def get_gpu_free_session(memory_fraction=0.1): 358 | import tensorflow as tf 359 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=memory_fraction) 360 | return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 361 | 362 | 363 | def parse_params(): 364 | params = {} 365 | for i, param in enumerate(sys.argv): 366 | if '-' in param: 367 | params[param[1:]] = sys.argv[i+1] 368 | print(params) 369 | return params 370 | 371 | 372 | def print_flags(FLAGS, print=True): 373 | x = FLAGS.input_path 374 | res = 'FLAGS:' 375 | for i in sorted(FLAGS.__dict__['__flags'].items()): 376 | item = str(i)[2:-1].split('\', ') 377 | res += '\n%20s \t%s' % (item[0] + ':', item[1]) 378 | if print: 379 | print_info(res) 380 | return res 381 | 382 | 383 | def _abbreviate_string(value): 384 | str_value = str(value) 385 | abbr = [letter for letter in str_value if letter.isupper()] 386 | if len(abbr) > 1: 387 | return ''.join(abbr) 388 | 389 | if len(str_value.split('_')) > 2: 390 | parts = str_value.split('_') 391 | letters = ''.join(x[0] for x in parts) 392 | return letters 393 | return value 394 | 395 | 396 | def to_file_name(obj, folder=None, ext=None, append_timestamp=False): 397 | name, postfix = '', '' 398 | od = collections.OrderedDict(sorted(obj.items())) 399 | for _, key in enumerate(od): 400 | value = obj[key] 401 | if value is None: 402 | value = 'na' 403 | #FUNC and OBJECTS 404 | if 'function' in str(value): 405 | value = str(value).split()[1].split('.')[0] 406 | parts = value.split('_') 407 | if len(parts) > 1: 408 | value = ''.join(list(map(lambda x: x.upper()[0], parts))) 409 | elif ' at ' in str(value): 410 | value = (str(value).split()[0]).split('.')[-1] 411 | value = _abbreviate_string(value) 412 | elif isinstance(value, type): 413 | value = _abbreviate_string(value.__name__) 414 | # FLOATS 415 | if isinstance(value, float) or isinstance(value, np.float32): 416 | if value < 0.0001: 417 | value = '%.6f' % value 418 | elif value > 1000000: 419 | value = '%.0f' % value 420 | else: 421 | value = '%.4f' % value 422 | value = value.rstrip('0') 423 | #INTS 424 | if isinstance(value, int): 425 | value = '%02d' % value 426 | #LIST 427 | if isinstance(value, list): 428 | value = '|'.join(map(str, value)) 429 | 430 | truncate_threshold = 20 431 | value = _abbreviate_string(value) 432 | if len(value) > truncate_threshold: 433 | print_info('truncating this: %s %s' % (key, value)) 434 | value = value[0:20] 435 | 436 | if 'suf' in key or 'postf' in key: 437 | continue 438 | 439 | name += '__%s|%s' % (key, str(value)) 440 | 441 | if 'suf' in obj: 442 | prefix_value = obj['suf'] 443 | else: 444 | prefix_value = FLAGS.suffix 445 | if 'postf' in obj: 446 | prefix_value += '_%s' % obj['postf'] 447 | name = prefix_value + name 448 | 449 | if ext: 450 | name += '.' 
+ ext 451 | if folder: 452 | name = os.path.join(folder, name) 453 | return name 454 | 455 | 456 | def dict_to_ordereddict(dict): 457 | return collections.OrderedDict(sorted(dict.items())) 458 | 459 | 460 | def configure_folders_2(FLAGS, meta): 461 | folder_meta = meta.copy() 462 | folder_meta.pop('init') 463 | folder_meta.pop('lr') 464 | folder_meta.pop('opt') 465 | folder_meta.pop('bs') 466 | folder_name = to_file_name(folder_meta) + '/' 467 | checkpoint_folder = os.path.join(TEMP_FOLDER, folder_name) 468 | log_folder = os.path.join(checkpoint_folder, 'log') 469 | mkdir([TEMP_FOLDER, IMAGE_FOLDER, checkpoint_folder, log_folder]) 470 | FLAGS.save_path = checkpoint_folder 471 | FLAGS.logdir = log_folder 472 | return checkpoint_folder, log_folder 473 | 474 | 475 | # precision/recall evaluation 476 | 477 | def evaluate_precision_recall(y, target, labels): 478 | import sklearn.metrics as metrics 479 | target = target[:len(y)] 480 | num_classes = max(target) + 1 481 | results = [] 482 | for i in range(num_classes): 483 | class_target = _extract_single_class(i, target) 484 | class_y = _extract_single_class(i, y) 485 | 486 | results.append({ 487 | 'precision': metrics.precision_score(class_target, class_y), 488 | 'recall': metrics.recall_score(class_target, class_y), 489 | 'f1': metrics.f1_score(class_target, class_y), 490 | 'fraction': sum(class_target)/len(target), 491 | '#of_class': int(sum(class_target)), 492 | 'label': labels[i], 493 | 'label_id': i 494 | # 'tp': tp 495 | }) 496 | print('%d/%d' % (i, num_classes), results[-1]) 497 | accuracy = metrics.accuracy_score(target, y) 498 | return accuracy, results 499 | 500 | 501 | def _extract_single_class(i, classes): 502 | res, i = classes + 1, i + 1 503 | res[res != i] = 0 504 | res = np.asarray(res) 505 | res = res / i 506 | return res 507 | 508 | 509 | 510 | def print_relevance_info(relevance, prefix='', labels=None): 511 | labels = labels if labels is not None else np.arange(len(relevance)) 512 | separator = '\n\t' if len(relevance) > 3 else ' ' 513 | result = '%s format: [" label": f1_score (precision recall) label_percentage]' % prefix 514 | format = '\x1B[0m%s\t"%25s":\x1B[31;40m%.2f\x1B[0m (%.2f %.2f) %d%%' 515 | for i, label_relevance in enumerate(relevance): 516 | result += format % (separator, 517 | str(labels[i]), 518 | label_relevance['f1'], 519 | label_relevance['precision'], 520 | label_relevance['recall'], 521 | int(label_relevance['fraction']*10000)/100. 
522 | ) 523 | print(result) 524 | 525 | 526 | def disalbe_tensorflow_warnings(): 527 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 528 | 529 | 530 | @timeit 531 | def images_to_sprite(arr, path=None): 532 | assert len(arr) <= 100*100 533 | arr = arr[...,:3] 534 | resized = [scipy.misc.imresize(x[..., :3], size=[80, 80]) for x in arr] 535 | base = np.zeros([8000, 8000, 3], np.uint8) 536 | 537 | for i in range(100): 538 | for j in range(100): 539 | index = j+100*i 540 | if index < len(resized): 541 | base[80*i:80*i+80, 80*j:80*j+80] = resized[index] 542 | scipy.misc.imsave(path, base) 543 | 544 | 545 | def generate_tsv(num, path): 546 | with open(path, mode='w') as f: 547 | [f.write('%d\n' % i) for i in range(num)] 548 | 549 | 550 | def paste_patch(patch, base_size=40, upper_half=True): 551 | channels = patch.shape[-1] 552 | base = np.zeros((base_size, base_size, channels), dtype=np.uint8) 553 | 554 | position_x = random.randint(0, base_size - patch.shape[0]) 555 | position_y = random.randint(0, base_size / 2 - patch.shape[1]) 556 | if not upper_half: position_y += int(base_size / 2) 557 | 558 | base[position_x:position_x + patch.shape[0], position_y:position_y + patch.shape[1], :] = patch 559 | return base 560 | 561 | 562 | if __name__ == '__main__': 563 | data = [] 564 | for i in range(10): 565 | data.append((str(i), np.random.rand(1000))) 566 | 567 | 568 | mask_busy_gpus() -------------------------------------------------------------------------------- /video_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import autoencoder as ae 4 | import os 5 | import sys 6 | import utils as ut 7 | import visualization as vis 8 | import matplotlib.pyplot as plt 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | 12 | 13 | # 1. load model 14 | # 2. convert to static model 15 | # 3. fetch data 16 | # 4. 
feed one by one into the model -> create projection 17 | 18 | def restore_flags(path): 19 | x = FLAGS.input_path 20 | log_file = os.path.join(path, '!note.txt') 21 | 22 | with open(log_file, 'r') as f: 23 | lines = f.readlines() 24 | 25 | # collect flags 26 | flag_stored = {} 27 | for l in lines: 28 | if ': \t' in l: 29 | parts = l[:-1].split(': \t') 30 | parts[0].strip() 31 | key = parts[0].strip() 32 | val = parts[1] 33 | flag_stored[key] = val 34 | 35 | print(flag_stored) 36 | print(FLAGS.__dict__['__flags']) 37 | 38 | flag_current = FLAGS.__dict__['__flags'] 39 | for k in flag_stored.keys(): 40 | if k in flag_current: 41 | # fix int issue '10.0' => '10' 42 | if type(flag_current[k]) == int: 43 | flag_stored[k] = flag_stored[k].split('.')[0] 44 | 45 | if type(flag_current[k]) == str: 46 | flag_stored[k] = flag_stored[k][1:-1] 47 | print(flag_stored[k], k) 48 | 49 | type_friendly_val = type(flag_current[k])(flag_stored[k]) 50 | # print(k, type(dest[k]), flags[k], type(dest[k])(flags[k])) 51 | flag_current[k] = type_friendly_val 52 | 53 | 54 | # def print_video(ecn, img, reco): 55 | 56 | 57 | 58 | 59 | def data_to_img(img): 60 | h, w, c = img.shape 61 | base = np.zeros((h, 2*w, 3)) 62 | base[:, :w, :] = img[:,:,:3] 63 | base[:, w:] = np.expand_dims(img[:,:,3], axis=2) 64 | return base 65 | 66 | 67 | def _show_image(img, original=True): 68 | # index = 322 if original else 326 69 | index = (0, 1) if original else (4, 1) 70 | # ax = plt.subplot(index) 71 | ax = plt.subplot2grid((7, 2), index, rowspan=3) 72 | ax.imshow(img) 73 | ax.set_title('Original' if original else 'Reconstruction') 74 | ax.axis('off') 75 | 76 | 77 | TAIL_LENGTH = 30 78 | 79 | 80 | def animate(cur, enc, img, reco): 81 | fig = vis.get_figure() 82 | 83 | original = data_to_img(img[cur]) 84 | reconstr = data_to_img(reco[cur]) 85 | 86 | _show_image(original) 87 | _show_image(reconstr, original=False) 88 | 89 | # animation 90 | ax = plt.subplot2grid((7, 2), (0, 0), rowspan=7, projection='3d') 91 | ax.set_title('Trajectory') 92 | ax.axes.get_xaxis().set_ticks([]) 93 | ax.axes.get_yaxis().set_ticks([]) 94 | ax.set_zticks([]) 95 | 96 | if enc.shape[1] > 3: 97 | enc = enc[:, :4] 98 | 99 | # white 100 | data = enc[cur:] 101 | ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='w', s=1, zorder=15) 102 | # old 103 | tail_len = max(0, cur - TAIL_LENGTH) 104 | # print(tail_len) 105 | if tail_len > 0: 106 | data = enc[:tail_len] 107 | ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='black', zorder=10, s=1,) 108 | # recent 109 | for i in range(0, -TAIL_LENGTH, -1): 110 | j = i + cur 111 | if j >= 0: 112 | # print(cur, i) 113 | data = enc[j] 114 | ax.scatter(data[0], data[1], data[2], c='b', s=(i+TAIL_LENGTH)/5, zorder=5) 115 | 116 | plt.show() 117 | 118 | if __name__ == '__main__': 119 | path = os.getcwd() 120 | path = '/mnt/code/vd/TensorFlow_DCIGN/tmp/pred.16c3s2_32c3s2_32c3s2_16c3_f3__i_romb8.5.6' 121 | 122 | # write new flags: batch_size, path, save_path 123 | restore_flags(path) 124 | FLAGS.batch_size = 1 125 | 126 | path = FLAGS.input_path if len(FLAGS.test_path) == 0 else FLAGS.test_path 127 | path = '../../' + path 128 | path = sys.argv[-1] if len(sys.argv) == 3 else path 129 | FLAGS.test_path, FLAGS.input_path = path, path 130 | FLAGS.save_path = os.getcwd() 131 | FLAGS.model = 'ae' 132 | FLAGS.new_blur = False 133 | print(FLAGS.net, FLAGS.new_blur, FLAGS.test_path, FLAGS.input_path, os.getcwd()) 134 | 135 | # run inference 136 | model = ae.Autoencoder(need_forlders=False) 137 | enc, reco = model.inference(max=100) 138 | img = 
model.test_set 139 | 140 | # enc, img, reco = np.arange(0, 364*3).reshape((364, 3)), np.random.rand(364, 80, 160, 4), np.random.rand(364, 80, 160, 4) 141 | 142 | for i in range(len(enc)): 143 | animate(i+100, enc, img, reco) 144 | 145 | # (364, 3)(364, 80, 160, 4) 146 | 147 | -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | from sklearn.manifold import TSNE 2 | import sklearn.manifold as mn 3 | import matplotlib.pyplot as plt 4 | import sklearn.metrics.pairwise as pw 5 | import scipy.spatial.distance as dist 6 | import numpy as np 7 | from mpl_toolkits.mplot3d import Axes3D 8 | import os 9 | import sys 10 | import utils as ut 11 | import time 12 | import tensorflow as tf 13 | import matplotlib.ticker as ticker 14 | 15 | 16 | # Next line to silence pyflakes. This import is needed. 17 | Axes3D 18 | 19 | # colors = ['grey', 'red', 'magenta'] 20 | 21 | FLAGS = tf.app.flags.FLAGS 22 | COLOR_MAP = plt.cm.Spectral 23 | PICKER_SENSITIVITY = 5 24 | 25 | 26 | def scatter(plot, data, is3d, colors): 27 | if is3d: 28 | plot.scatter(data[:, 0], data[:, 1], data[:, 2], 29 | marker='.', 30 | c=colors, 31 | cmap=plt.cm.Spectral, 32 | picker=PICKER_SENSITIVITY) 33 | else: 34 | plot.scatter(data[:, 0], data[:, 1], 35 | c=colors, 36 | cmap=plt.cm.Spectral, 37 | picker=PICKER_SENSITIVITY) 38 | 39 | 40 | def print_data_only(data, file_name, fig=None, interactive=False): 41 | fig = get_figure(fig) 42 | subplot_number = 121 if fig is not None else 111 43 | fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 1) 44 | 45 | colors = _build_radial_colors(len(data)) 46 | if data.shape[1] > 2: 47 | subplot = plt.subplot(subplot_number, projection='3d') 48 | subplot.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors, 49 | cmap=COLOR_MAP, picker=PICKER_SENSITIVITY) 50 | subplot.plot(data[:, 0], data[:, 1], data[:, 2]) 51 | else: 52 | subplot = plt.subplot(subplot_number) 53 | subplot.scatter(data[:, 0], data[:, 1], c=colors, 54 | cmap=COLOR_MAP, picker=PICKER_SENSITIVITY) 55 | if not interactive: 56 | save_fig(file_name, fig) 57 | 58 | 59 | def create_gif_from_folder(folder): 60 | # fig = plt.figure() 61 | # gif_folder = os.path.join(FLAGS.save_path, 'gif') 62 | # if not os.path.exists(gif_folder): 63 | # os.mkdir(gif_folder) 64 | # epoch = file_name.split('_e|')[-1].split('_')[0] 65 | # gif_path = os.path.join(gif_folder, epoch) 66 | # subplot = plt.subplot(111, projection='3d') 67 | # subplot.scatter(data[0], data[1], data[2], c=colors, cmap=color_map) 68 | # save_fig(file_name) 69 | pass 70 | 71 | 72 | def manual_pca(data, std_threshold=0.01, num_threshold=3): 73 | """remove meaningless dimensions""" 74 | std = data[0:300].std(axis=0) 75 | 76 | order = np.argsort(std)[::-1] 77 | # order = np.arange(0, data.shape[1]).astype(np.int32) 78 | std = std[order] 79 | # filter components by STD but take at least 3 80 | 81 | meaningless = [order[i] for i, x in enumerate(std) if x <= std_threshold] 82 | if any(meaningless) and data.shape[1] > 3: 83 | # ut.print_info('meaningless dimensions on visualization: %s' % str(meaningless)) 84 | pass 85 | 86 | order = [order[i] for i, x in enumerate(std) if x > std_threshold or i < num_threshold] 87 | order.sort() 88 | return data[:, order] 89 | 90 | 91 | def _needs_hessian(manifold): 92 | if hasattr(manifold, 'dissimilarity') and manifold.dissimilarity == 'precomputed': 93 | return True 94 | if hasattr(manifold, 'metric') and 
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
1 | from sklearn.manifold import TSNE
2 | import sklearn.manifold as mn
3 | import matplotlib.pyplot as plt
4 | import sklearn.metrics.pairwise as pw
5 | import scipy.spatial.distance as dist
6 | import numpy as np
7 | from mpl_toolkits.mplot3d import Axes3D
8 | import os
9 | import sys
10 | import utils as ut
11 | import time
12 | import tensorflow as tf
13 | import matplotlib.ticker as ticker
14 |
15 |
16 | # Next line to silence pyflakes. This import is needed.
17 | Axes3D
18 |
19 | # colors = ['grey', 'red', 'magenta']
20 |
21 | FLAGS = tf.app.flags.FLAGS
22 | COLOR_MAP = plt.cm.Spectral
23 | PICKER_SENSITIVITY = 5
24 |
25 |
26 | def scatter(plot, data, is3d, colors):
27 |   if is3d:
28 |     plot.scatter(data[:, 0], data[:, 1], data[:, 2],
29 |                  marker='.',
30 |                  c=colors,
31 |                  cmap=plt.cm.Spectral,
32 |                  picker=PICKER_SENSITIVITY)
33 |   else:
34 |     plot.scatter(data[:, 0], data[:, 1],
35 |                  c=colors,
36 |                  cmap=plt.cm.Spectral,
37 |                  picker=PICKER_SENSITIVITY)
38 |
39 |
40 | def print_data_only(data, file_name, fig=None, interactive=False):
41 |   subplot_number = 121 if fig is not None else 111  # decide before the figure is (re)created below
42 |   fig = get_figure(fig)
43 |   fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 1)
44 |
45 |   colors = _build_radial_colors(len(data))
46 |   if data.shape[1] > 2:
47 |     subplot = plt.subplot(subplot_number, projection='3d')
48 |     subplot.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors,
49 |                     cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
50 |     subplot.plot(data[:, 0], data[:, 1], data[:, 2])
51 |   else:
52 |     subplot = plt.subplot(subplot_number)
53 |     subplot.scatter(data[:, 0], data[:, 1], c=colors,
54 |                     cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
55 |   if not interactive:
56 |     save_fig(file_name, fig)
57 |
58 |
59 | def create_gif_from_folder(folder):
60 |   # fig = plt.figure()
61 |   # gif_folder = os.path.join(FLAGS.save_path, 'gif')
62 |   # if not os.path.exists(gif_folder):
63 |   #   os.mkdir(gif_folder)
64 |   # epoch = file_name.split('_e|')[-1].split('_')[0]
65 |   # gif_path = os.path.join(gif_folder, epoch)
66 |   # subplot = plt.subplot(111, projection='3d')
67 |   # subplot.scatter(data[0], data[1], data[2], c=colors, cmap=color_map)
68 |   # save_fig(file_name)
69 |   pass
70 |
71 |
72 | def manual_pca(data, std_threshold=0.01, num_threshold=3):
73 |   """Remove meaningless dimensions: drop low-variance components, but keep at least 3."""
74 |   std = data[0:300].std(axis=0)
75 |
76 |   order = np.argsort(std)[::-1]
77 |   # order = np.arange(0, data.shape[1]).astype(np.int32)
78 |   std = std[order]
79 |   # filter components by STD but take at least 3
80 |
81 |   meaningless = [order[i] for i, x in enumerate(std) if x <= std_threshold]
82 |   if len(meaningless) > 0 and data.shape[1] > 3:
83 |     # ut.print_info('meaningless dimensions on visualization: %s' % str(meaningless))
84 |     pass
85 |
86 |   order = [order[i] for i, x in enumerate(std) if x > std_threshold or i < num_threshold]
87 |   order.sort()
88 |   return data[:, order]
89 |
90 |
91 | def _needs_hessian(manifold):
92 |   if hasattr(manifold, 'dissimilarity') and manifold.dissimilarity == 'precomputed':
93 |     return True
94 |   if hasattr(manifold, 'metric') and manifold.metric == 'precomputed':
95 |     return True
96 |   return False
97 |
98 |
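A minimal sketch of manual_pca's behaviour on toy data, assuming the function is in scope (the synthetic arrays are illustrative): near-constant columns fall below the 0.01 STD threshold and are dropped, while the highest-variance columns always survive.

    import numpy as np

    rng = np.random.RandomState(0)
    toy = np.hstack([
        rng.rand(300, 3),               # three informative dimensions
        0.5 + rng.rand(300, 2) * 1e-4,  # two near-constant dimensions
    ])
    assert manual_pca(toy).shape == (300, 3)  # near-constant columns removed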
99 | @ut.images_to_uint8
100 | @ut.timeit
101 | def visualize_encoding(encodings, folder=None, meta=None, original=None, reconstruction=None):
102 |   if original is not None and np.max(original) < 10:
103 |     original = (original * 255).astype(np.uint8)
104 |   # print('np', np.max(original), np.max(reconstruction), np.min(original), np.min(reconstruction),
105 |   #       original.dtype, reconstruction.dtype)
106 |   file_path = None
107 |   if folder:
108 |     meta = dict(meta or {}, postfix='pca')  # avoid mutating a shared default dict
109 |     file_path = ut.to_file_name(meta, folder, 'jpg')
110 |   encodings = manual_pca(encodings)
111 |
112 |   if original is not None:
113 |     assert len(original) == len(reconstruction)
114 |     fig = get_figure()
115 |
116 |     # print('reco max:', np.max(reconstruction))
117 |     column_picture, height = _stitch_images(original, reconstruction)
118 |     subplot, proportion = (122, 1) if encodings.shape[1] <= 3 else (155, 3)
119 |     picture = _reshape_column_image(column_picture, height, proportion=proportion)
120 |     if picture.shape[-1] == 1:
121 |       picture = picture.squeeze()
122 |     plt.subplot(subplot).set_title("Original/reconstruction")
123 |     plt.subplot(subplot).imshow(picture)
124 |     plt.subplot(subplot).axis('off')
125 |
126 |     visualize_encodings(encodings, file_name=file_path, fig=fig, grid=(3, 5), skip_every=5)
127 |   else:
128 |     visualize_encodings(encodings, file_name=file_path)
129 |
130 |
131 | # @ut.timeit
132 | def plot_encoding_crosssection(encodings, file_path, original=None, reconstruction=None, interactive=False):
133 |   # print(encodings.shape)
134 |   # print(original.shape)
135 |   # print(reconstruction.shape)
136 |   # encodings = manual_pca(encodings)
137 |
138 |   fig = get_figure()
139 |   if original is not None:
140 |     assert len(original) == len(reconstruction), (len(original), len(reconstruction))
141 |     subplot, proportion = visualize_cross_section_with_reco(encodings)
142 |     picture = stitch_side_by_side(original, reconstruction, proportion)
143 |     subplot.imshow(picture)
144 |   else:
145 |     visualize_cross_section(encodings)
146 |   if not interactive:
147 |     save_fig(file_path, fig)
148 |   return fig
149 |
150 |
151 | def plot_reconstruction(original, reconstruction, meta={'debug': 'true'}, interactive=False):
152 |   # if not interactive:
153 |   #   _get_figure()
154 |   picture = stitch_side_by_side(original, reconstruction)
155 |   plt.imshow(picture)
156 |   if not interactive:
157 |     file_path = ut.to_file_name(meta, FLAGS.save_path, 'jpg')
158 |     save_fig(file_path)
159 |   else:
160 |     plt.draw()
161 |     plt.pause(0.001)
162 |
163 |
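The interactive flag above switches between writing a JPEG and live-updating the current figure; a minimal live-preview sketch with synthetic frames (shapes are illustrative, and plot_reconstruction is assumed to be in scope):

    import numpy as np

    orig = (np.random.rand(8, 80, 160, 3) * 255).astype(np.uint8)
    reco = (np.random.rand(8, 80, 160, 3) * 255).astype(np.uint8)
    for t in range(len(orig)):
        # draw + 1 ms pause per frame, no file written
        plot_reconstruction(orig[t:t+1], reco[t:t+1], interactive=True)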
164 | # Image stitching
165 |
166 |
167 | @ut.images_to_uint8
168 | def stitch_side_by_side(original, reconstruction, proportion=1):
169 |   """
170 |   Stitch 2 lists of images together for convenient display in a single
171 |   rectangular shape of given side proportion
172 |   """
173 |   # print(np.max(original), original.dtype)
174 |   if not np.max(original) >= 10:
175 |     print("unexpected value range for the original pictures", np.max(original), original.dtype)
176 |   column_picture, height = _stitch_images(original, reconstruction)
177 |   picture = _reshape_column_image(column_picture, height, proportion=proportion)
178 |   if picture.shape[-1] == 1:
179 |     picture = picture.squeeze()
180 |   return picture
181 |
182 |
183 | def _stitch_images(*args):
184 |   """Receives two arrays of pictures and stitches them side by side into one column picture"""
185 |   assert len(args) == 2
186 |   lines, height, width, channels = args[0].shape
187 |   pad_value = 0
188 |   stack = args[0]
189 |   # print([(type(x), x.shape) for x in args])
190 |   # print(np.min(args[0]), np.max(args[0]), args[0].dtype, args[0].mean())
191 |   # print(np.min(args[1]), np.max(args[1]), args[1].dtype, args[1].mean())
192 |   if len(args) > 1:
193 |     for i in range(len(args) - 1):
194 |       stack = np.concatenate((stack, args[i + 1]), axis=2)
195 |   # stack - array of lines of pictures (arr_0[0], arr_1[0], ...)
196 |   # concatenate lines in one picture (height = tile_h * #lines)
197 |   picture_lines = stack.reshape(lines * height, stack.shape[2], channels)
198 |   picture_lines = np.hstack((
199 |     picture_lines,
200 |     np.ones((lines * height, 2, channels), dtype=np.uint8) * pad_value))  # pad 2 pixels
201 |
202 |   # slice/reshape to have better image proportions
203 |   return picture_lines, height
204 |
205 |
206 | def _reshape_column_image(column_picture, height, proportion=1):
207 |   """
208 |   Take a 'column' image of independent horizontal segments of the given height
209 |   and reshape it to fit into the given proportion
210 |
211 |   |img1|
212 |   |img2|        |img1 img4|
213 |   |img3|   =>   |img2 img5|
214 |   |img4|        |img3     |
215 |   |img5|
216 |
217 |   """
218 |   lines = int(column_picture.shape[0] / height)
219 |   width = column_picture.shape[1]
220 |
221 |   column_size = int(np.ceil(np.sqrt(lines * proportion * (width / height))))
222 |
223 |   # grow the columns if the single last column would contain only 1 or 2 pictures
224 |   if lines % column_size <= lines / column_size:
225 |     column_size += 1
226 |
227 |   # count = int(column_picture.shape[0] / height)
228 |   _, _, channels = column_picture.shape
229 |
230 |   picture = column_picture[0:column_size * height, :, :]
231 |
232 |   for i in range(int(lines / column_size)):
233 |     start, stop = column_size * height * (i + 1), column_size * height * (i + 2)
234 |     if start >= len(column_picture):
235 |       break
236 |     if stop < len(column_picture):
237 |       picture = np.hstack((picture, column_picture[start:stop]))
238 |     else:
239 |       last_column = np.vstack((
240 |         column_picture[start:],
241 |         np.ones((stop - len(column_picture), column_picture.shape[1], column_picture.shape[2]),
242 |                 dtype=np.uint8)))
243 |       picture = np.hstack((picture, last_column))
244 |
245 |   return picture
246 |
247 |
248 | figure_shape = [3095, 2352, 3]
249 |
250 |
251 | def get_figure(fig=None, shape=figure_shape):
252 |   if fig is not None:
253 |     return fig
254 |   dpi = 300.
255 |   fig = plt.figure(num=0, figsize=[shape[0]/dpi, shape[1]/dpi], dpi=dpi)
256 |   fig.clf()
257 |   return fig
258 |
259 |
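As a worked example of the column arithmetic above (numbers are illustrative): 20 stitched rows of height 80 and width 322 give an initial column_size of ceil(sqrt(20 * 1 * 322/80)) = 9; since the leftover column would then hold only 20 % 9 = 2 pictures, the guard bumps it to 10, producing two full columns of ten images each.

    import numpy as np

    lines, height, width, proportion = 20, 80, 322, 1
    column_size = int(np.ceil(np.sqrt(lines * proportion * (width / height))))  # 9
    if lines % column_size <= lines / column_size:  # 2 <= 2.22: avoid a near-empty column
        column_size += 1                            # 10 -> layout is 2 columns x 10 images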
260 | def _plot_single_cross_section(data, select, subplot):
261 |   data = data[:, select]
262 |   # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0,
263 |   subplot.plot(data[:, 0], data[:, 1], color='black', lw=1, alpha=0.4)
264 |   subplot.plot(data[[-1, 0], 0], data[[-1, 0], 1], lw=1, alpha=0.8, color='red')
265 |   subplot.scatter(data[:, 0], data[:, 1], s=4, alpha=1.0, lw=0.5,
266 |                   c=_build_radial_colors(len(data)),
267 |                   marker=".",
268 |                   cmap=plt.cm.Spectral)
269 |   # data = np.vstack((data, np.asarray([data[0, :]])))
270 |   # subplot.plot(data[:, 0], data[:, 1], alpha=0.4)
271 |
272 |   subplot.set_xlabel('feature %d' % select[0], labelpad=-12)
273 |   subplot.set_ylabel('feature %d' % select[1], labelpad=-12)
274 |   subplot.set_xlim([-0.05, 1.05])
275 |   subplot.set_ylim([-0.05, 1.05])
276 |   subplot.xaxis.set_ticks([0, 1])
277 |   subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
278 |   subplot.yaxis.set_ticks([0, 1])
279 |   subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
280 |
281 |
282 | def _plot_single_cross_section_3d(data, select, subplot):
283 |   data = data[:, select]
284 |   # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0,
285 |   subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4)
286 |   subplot.plot(data[[-1, 0], 0], data[[-1, 0], 1], data[[-1, 0], 2], lw=1, alpha=0.8, color='red')
287 |   subplot.scatter(data[:, 0], data[:, 1], data[:, 2], s=4, alpha=1.0, lw=0.5,
288 |                   c=_build_radial_colors(len(data)),
289 |                   marker=".",
290 |                   cmap=plt.cm.Spectral)
291 |   # data = data[0::10]
292 |   # subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=2, alpha=0.8)
293 |
294 |   # data = np.vstack((data, np.asarray([data[0, :]])))
295 |   # subplot.plot(data[:, 0], data[:, 1], alpha=0.4)
296 |
297 |   subplot.set_xlabel('feature %d' % select[0], labelpad=-12)
298 |   subplot.set_ylabel('feature %d' % select[1], labelpad=-12)
299 |   subplot.set_zlabel('feature %d' % select[2], labelpad=-12)
300 |   subplot.set_xlim([-0.01, 1.01])
301 |   subplot.set_ylim([-0.01, 1.01])
302 |   subplot.set_zlim([-0.01, 1.01])
303 |   subplot.xaxis.set_ticks([0, 1])
304 |   subplot.yaxis.set_ticks([0, 1])
305 |   subplot.zaxis.set_ticks([0, 1])
306 |   subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
307 |   subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
308 |
309 |
310 | @ut.deprecated
311 | def visualize_cross_section(embeddings):
312 |   features = embeddings.shape[-1]
313 |   size = features - 1
314 |   for i in range(features):
315 |     for j in range(i + 1, features):
316 |       pos = i * size + j
317 |       subplot = plt.subplot(size, size, pos)
318 |       _plot_single_cross_section(embeddings, [i, j], subplot)
319 |
320 |   if features >= 3:
321 |     pos = (size + 1) * size - size + 1
322 |     subplot = plt.subplot(size + 1, size, pos)
323 |     _plot_single_cross_section(embeddings, [0, 1], subplot)
324 |   return size
325 |
326 |
327 | def visualize_cross_section_with_reco(embeddings):
328 |   features = embeddings.shape[-1]
329 |   size = features - 1
330 |   for i in range(features):
331 |     for j in range(i + 1, features):
332 |       pos = i * (size + 1) + j
333 |       subplot = plt.subplot(size, size + 1, pos)
334 |       _plot_single_cross_section(embeddings, [i, j], subplot)
335 |   reco_subplot = plt.subplot(1, size + 1, size + 1)  # the extra right-hand column
336 |   reco_subplot.axis('off')
337 |
338 |   if size >= 2:
339 |     single_size = size if size < 4 else int(size / 2)
340 |     pos = single_size * (single_size + 1) - (single_size + 1) + 1
341 |     subplot = plt.subplot(single_size, single_size + 1, pos, projection='3d')
342 |     _plot_single_cross_section_3d(embeddings, [0, 1, 2], subplot)
343 |   return reco_subplot, size
344 |
345 |
346 | # cross section end
347 |
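The subplot arithmetic in visualize_cross_section_with_reco maps each feature pair (i, j) to cell i*(size+1)+j of a size-by-(size+1) grid, keeping the extra right-hand column free for the reconstruction strip; a quick enumeration for four features (size = 3), just to illustrate:

    features = 4
    size = features - 1
    cells = [(i, j, i * (size + 1) + j)
             for i in range(features) for j in range(i + 1, features)]
    # [(0,1,1), (0,2,2), (0,3,3), (1,2,6), (1,3,7), (2,3,11)]
    # i.e. the upper triangle of a 3 x 4 grid; cells 4, 8 and 12 stay free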
348 |
349 | def visualize_encodings(encodings, file_name=None,
350 |                         grid=None, skip_every=999, fast=False, fig=None, interactive=False):
351 |   encodings = manual_pca(encodings)
352 |   if encodings.shape[1] <= 3:
353 |     return print_data_only(encodings, file_name, fig=fig, interactive=interactive)
354 |
355 |   encodings = encodings[0:720]
356 |   hessian_euc = dist.squareform(dist.pdist(encodings[0:720], 'euclidean'))
357 |   hessian_cos = dist.squareform(dist.pdist(encodings[0:720], 'cosine'))
358 |   grid = (3, 4) if grid is None else grid
359 |   project_ops = []
360 |
361 |   n = 2
362 |   project_ops.append(("LLE ltsa N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='ltsa')))
363 |   project_ops.append(("LLE modified N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='modified')))
364 |   project_ops.append(('MDS euclidean N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
365 |   project_ops.append(("TSNE 30/2000 N:%d" % n, TSNE(perplexity=30, n_components=n, init='pca', n_iter=2000)))
366 |   n = 3
367 |   project_ops.append(("LLE ltsa N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='ltsa')))
368 |   project_ops.append(("LLE modified N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='modified')))
369 |   project_ops.append(('MDS euclidean N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
370 |   project_ops.append(('MDS cosine N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))  # cosine distance matrix is selected by name below
371 |
372 |   plot_places = []  # grid cells to use, leaving every skip_every-th cell free
373 |   for i in range(12):
374 |     u, v = int(i / (skip_every - 1)), i % (skip_every - 1)
375 |     j = v + u * skip_every + 1
376 |     plot_places.append(j)
377 |
378 |   fig = get_figure(fig)
379 |   fig.set_size_inches(fig.get_size_inches()[0] * grid[0] / 1.,
380 |                       fig.get_size_inches()[1] * grid[1] / 2.0)
381 |
382 |   for i, (name, manifold) in enumerate(project_ops):
383 |     is3d = 'N:3' in name
384 |
385 |     try:
386 |       if is3d:
387 |         subplot = plt.subplot(grid[0], grid[1], plot_places[i], projection='3d')
388 |       else:
389 |         subplot = plt.subplot(grid[0], grid[1], plot_places[i])
390 |
391 |       data_source = encodings if not _needs_hessian(manifold) else \
392 |         (hessian_cos if 'cosine' in name else hessian_euc)
393 |       projections = manifold.fit_transform(data_source)
394 |       scatter(subplot, projections, is3d, _build_radial_colors(len(data_source)))
395 |       subplot.set_title(name)
396 |     except Exception:
397 |       print(name, "Unexpected error: ", sys.exc_info()[0], sys.exc_info()[1] if len(sys.exc_info()) > 1 else '')
398 |
399 |   visualize_data_same(encodings, grid=grid, places=plot_places[-4:])
400 |   if not interactive:
401 |     save_fig(file_name, fig)
402 |   ut.print_time('visualization finished')
403 |
404 |
405 | def save_fig(file_name, fig=None):
406 |   if not file_name:
407 |     plt.show()
408 |   else:
409 |     plt.savefig(file_name, dpi=300, facecolor='w', edgecolor='w',
410 |                 transparent=False, bbox_inches='tight', pad_inches=0.1,
411 |                 frameon=None)
412 |   # plt.close('all')
413 |
414 |
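Each manifold in the loop above receives either the raw encodings or a precomputed distance matrix, depending on _needs_hessian; the same sklearn pattern in isolation, on a small synthetic batch:

    import numpy as np
    import scipy.spatial.distance as dist
    import sklearn.manifold as mn

    enc = np.random.rand(50, 8)
    d_euc = dist.squareform(dist.pdist(enc, 'euclidean'))  # (50, 50) distance matrix
    mds = mn.MDS(2, max_iter=300, n_init=1, dissimilarity='precomputed')
    xy = mds.fit_transform(d_euc)                          # (50, 2) embedding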
415 | def _random_split(sequence, length, original):
416 |   if sequence is None or len(sequence) < length:
417 |     sequence = original.copy()
418 |     sequence = np.random.permutation(sequence)
419 |   return sequence[:length], sequence[length:]
420 |
421 |
422 | @ut.deprecated
423 | def visualize_data_same(data, grid, places):
424 |   assert len(places) == 4
425 |
426 |   all_dimensions = np.arange(0, data.shape[1]).astype(np.int8)
427 |   first_proj, left = _random_split(None, 2, all_dimensions)
428 |   first_color_indexes, _ = _random_split(left, 3, all_dimensions)
429 |   first_color = _data_to_colors(data, first_color_indexes - 2)
430 |
431 |   second_proj, left = _random_split(left, 2, all_dimensions)
432 |   second_color = _build_radial_colors(len(data))
433 |
434 |   third_proj, left = _random_split(left, 3, all_dimensions)
435 |   third_color_indexes, _ = _random_split(left, 3, all_dimensions)
436 |   third_color = _data_to_colors(data, third_color_indexes)
437 |
438 |   forth_proj = np.argsort(data.std(axis=0))[::-1][0:3]
439 |   forth_color = _build_radial_colors(len(data))
440 |
441 |   for i, (projection, color) in enumerate([
442 |       (first_proj, first_color),
443 |       (second_proj, second_color),
444 |       (third_proj, third_color),
445 |       (forth_proj, forth_color)]
446 |   ):
447 |     points = np.transpose(data[:, projection])
448 |
449 |     if len(projection) == 2:
450 |       subplot = plt.subplot(grid[0], grid[1], places[i])
451 |       subplot.scatter(points[0], points[1], c=color, cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
452 |     else:
453 |       subplot = plt.subplot(grid[0], grid[1], places[i], projection='3d')
454 |       subplot.scatter(points[0], points[1], points[2], c=color, cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
455 |     subplot.set_title('Data %s %s' % (str(projection), 'sequential color' if i % 2 == 1 else ''))
456 |
457 |
458 | @ut.deprecated
459 | def visualize_data_same_deprecated(data, grid, places, dims_as_colors=False):
460 |   assert len(places) == 4
461 |   dimensions = np.arange(0, np.min([6, data.shape[1]])).astype(np.int32)
462 |   assert len(dimensions) == data.shape[1] or len(dimensions) == 6
463 |   projections = [dimensions[x] for x in [[0, 1], [-1, -2], [0, 1, 2], [-1, -2, -3]]]
464 |   colors = _build_radial_colors(len(data))
465 |
466 |   for i, dims in enumerate(projections):
467 |     points = np.transpose(data[:, dims])
468 |     if dims_as_colors:
469 |       colors = _data_to_colors(np.delete(data.copy(), dims, axis=1))
470 |
471 |     if len(dims) == 2:
472 |       subplot = plt.subplot(grid[0], grid[1], places[i])
473 |       subplot.scatter(points[0], points[1], c=colors, cmap=COLOR_MAP)
474 |     else:
475 |       subplot = plt.subplot(grid[0], grid[1], places[i], projection='3d')
476 |       subplot.scatter(points[0], points[1], points[2], c=colors, cmap=COLOR_MAP)
477 |     subplot.set_title('Data %s' % str(dims))
478 |
479 |
480 | def _duplicate_array(array, repeats=None, total_length=None):  # tile: [a, b, c] -> [a, b, c, a, b, c, ...]
481 |   assert repeats is not None or total_length is not None
482 |
483 |   if repeats is None:
484 |     repeats = int(np.ceil(total_length / len(array)))
485 |   res = array.copy()
486 |   for i in range(repeats - 1):
487 |     res = np.concatenate((res, array))
488 |   return res if total_length is None else res[:total_length]
489 |
490 |
491 | def _duplicate_array_repeat(array, repeats=None, total_length=None):  # repeat in place: [a, b, c] -> [a, a, b, b, c, c, ...]
492 |   assert repeats is not None or total_length is not None
493 |
494 |   if repeats is None:
495 |     repeats = int(np.ceil(total_length / len(array)))
496 |   parts = [array for i in range(repeats)]
497 |   whole = np.stack(parts, axis=1)
498 |   whole = whole.reshape(np.prod(whole.shape))
499 |   return whole if total_length is None else whole[:total_length]
500 |
501 |
502 | def _build_radial_colors(length):
503 |   colors = np.arange(0, 180)
504 |   # colors = np.concatenate((colors, colors[::-1]))
505 |   colors = _duplicate_array_repeat(colors, total_length=length)
506 |   return colors
507 |
508 |
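The two duplicate helpers above differ in ordering: _duplicate_array tiles the whole sequence, while _duplicate_array_repeat repeats each element in place, which is what gives _build_radial_colors its smooth gradient. A quick check:

    import numpy as np

    a = np.array([1, 2, 3])
    _duplicate_array(a, total_length=6)         # -> [1 2 3 1 2 3]
    _duplicate_array_repeat(a, total_length=6)  # -> [1 1 2 2 3 3]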
509 | def _data_to_colors(data, indexes=None):
510 |   color_data = data[:, indexes] if indexes is not None else data
511 |   shape = color_data.shape
512 |
513 |   if shape[1] < 3:
514 |     add = 3 - shape[1]
515 |     add = np.ones((shape[0], add)) * 0.5
516 |     color_data = np.concatenate((color_data, add), axis=1)
517 |   elif shape[1] > 3:
518 |     color_data = color_data[:, 0:3]
519 |
520 |   if np.max(color_data) <= 1:
521 |     color_data *= 256
522 |   color_data = color_data.astype(np.int32)
523 |   assert np.mean(color_data) <= 256
524 |   color_data[color_data > 255] = 255
525 |   color_data *= np.asarray([256 ** 2, 256, 1])  # pack channels into one integer: 0xRRGGBB
526 |
527 |   color_data = np.sum(color_data, axis=1)
528 |   color_data = ["#%06x" % c for c in color_data]
529 |   return color_data
530 |
531 |
532 | def visualize_available_data(root='./', reembed=True, with_mds=False):
533 |   tf.app.flags.DEFINE_string('suffix', 'grid', '')
534 |   FLAGS.suffix = '__' if with_mds else 'grid'
535 |   assert os.path.exists(root)
536 |
537 |   files = _list_embedding_files(root, reembed=reembed)
538 |   total_files = len(files)
539 |
540 |   for i, file in enumerate(files):
541 |     print('%d/%d %s' % (i + 1, total_files, file))
542 |     data = np.loadtxt(file)
543 |     png_name = file.replace('.txt', '_pca.png')
544 |
545 |     visualize_encodings(data, file_name=png_name)
546 |
547 |
548 | def _list_embedding_files(root, reembed=False):
549 |   encoding_files = []
550 |   for root, dirs, files in os.walk(root):
551 |     if '/tmp' in root:
552 |       for file in files:
553 |         if '.txt' in file and 'meta' not in file:
554 |           full_path = os.path.join(root, file)
555 |           if not reembed:
556 |             vis_file = full_path.replace('.txt', '_pca.png')
557 |             if os.path.exists(vis_file):
558 |               continue
559 |           encoding_files.append(full_path)
560 |   return encoding_files
561 |
562 |
563 | if __name__ == '__main__':
564 |   path = '../../encodings__e|500__z_ac|96.5751.txt'
565 |   x = np.loadtxt(path)
566 |   # x = np.random.rand(100, 5)
567 |   x = manual_pca(x)
568 |   x = x[:360]
569 |   visualize_cross_section_with_reco(x)
570 |   plt.tight_layout()
571 |   plt.show()
572 |
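The channel packing in _data_to_colors turns each row into a '#rrggbb' string via positional weights; in isolation, with two toy colours:

    import numpy as np

    rgb = np.array([[255, 0, 0], [0, 128, 255]])
    packed = np.sum(rgb * np.asarray([256 ** 2, 256, 1]), axis=1)
    print(["#%06x" % c for c in packed])  # ['#ff0000', '#0080ff']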
--------------------------------------------------------------------------------
/visualize_latest.py:
--------------------------------------------------------------------------------
1 | import visualization as vi
2 | import utils as ut
3 | import input as inp
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import tensorflow as tf
7 | import os
8 | import sys
9 | vis = vi  # visualization was already imported above; keep the second alias used below
10 |
11 | from mpl_toolkits.mplot3d import Axes3D
12 |
13 | # Next line to silence pyflakes. This import is needed.
14 | Axes3D
15 |
16 | FLAGS = tf.app.flags.FLAGS
17 |
18 |
19 | def print_data(data, fig, subplot, is_3d=True):
20 |   colors = np.arange(0, 180)
21 |   colors = np.concatenate((colors, colors[::-1]))
22 |   colors = vi._duplicate_array(colors, total_length=len(data))
23 |
24 |   if is_3d:
25 |     subplot = fig.add_subplot(subplot, projection='3d')
26 |     subplot.set_title('All data')
27 |     subplot.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors, cmap=plt.cm.Spectral, picker=5)
28 |   else:
29 |     subsample = data[0:360] if len(data) < 2000 else data[0:720]
30 |     subsample = np.concatenate((subsample, subsample))[0:len(subsample)+1]
31 |     ut.print_info('subsample shape %s' % str(subsample.shape))
32 |     subsample_colors = colors[0:len(subsample)]
33 |     subplot = fig.add_subplot(subplot)
34 |     subplot.set_title('First 360 elem')
35 |     subplot.plot(subsample[:, 0], subsample[:, 1], picker=0)
36 |     subplot.plot(subsample[0, 0], subsample[0, 1], picker=0)
37 |     subplot.scatter(subsample[:, 0], subsample[:, 1], s=50, c=subsample_colors,
38 |                     cmap=plt.cm.Spectral, picker=5)
39 |   return subplot
40 |
41 |
42 | class EncodingVisualizer:
43 |   def __init__(self, fig, data):
44 |     self.data = data
45 |     self.fig = fig
46 |     vi.visualize_encodings(data, grid=(3, 5), skip_every=5, fast=fast, fig=fig, interactive=True)
47 |     plt.subplot(155).set_title('hold on')
48 |     # fig.canvas.mpl_connect('button_press_event', self.on_click)
49 |     fig.canvas.mpl_connect('pick_event', self.on_pick)
50 |     try:
51 |       # if True:
52 |       ut.print_info('Checkpoint: %s' % FLAGS.load_from_checkpoint)
53 |       self.model = dm.DoomModel()  # dm is not imported here; the except branch below handles the failure
54 |       self.reconstructions = self.model.decode(data)
55 |     except Exception:
56 |       ut.print_info("Model could not load from checkpoint %s" % str(sys.exc_info()), color=31)
57 |       self.original_data, _ = inp.get_images(FLAGS.input_path)
58 |       self.reconstructions = np.zeros(self.original_data.shape).astype(np.uint8)
59 |     ut.print_info('INPUT: %s' % FLAGS.input_path.split('/')[-3])
60 |     self.original_data, _ = inp.get_images(FLAGS.input_path)
61 |
62 |
63 |   def on_pick(self, event):
64 |     print(event)
65 |     ind = event.ind
66 |     print(ind)
67 |     print(any([x for x in ind if x < 20]))
68 |     orig = self.original_data[ind]
69 |     reco = self.reconstructions[ind]
70 |     column_picture, height = vi._stitch_images(orig, reco)
71 |     picture = vi._reshape_column_image(column_picture, height, proportion=3)
72 |
73 |     title = ''
74 |     for i in range(len(ind)):
75 |       title += ' ' + str(ind[i])
76 |       if (i+1) % 8 == 0:
77 |         title += '\n'
78 |     plt.subplot(155).set_title(title)
79 |     plt.subplot(155).imshow(picture)
80 |     plt.show()
81 |
82 |   def on_click(self, event):
83 |     print('click', event)
84 |
85 |
86 | def visualize_latest_from_visualization_folder(folder='./visualizations/', file=None):
87 |   if file is None:
88 |     file = ut.get_latest_file(folder, filter=r'.*\d+\.txt$')
89 |     ut.print_info('Encoding file: %s' % file.split('/')[-1])
90 |   data = np.loadtxt(file)  # [0:360]
91 |   fig = plt.figure()
92 |   vi.visualize_encodings(data, fast=fast, fig=fig, interactive=True)
93 |   fig.suptitle(file.split('/')[-1])
94 |   fig.tight_layout()
95 |   plt.show()
96 |
97 |
98 | def visualize_from_checkpoint(checkpoint, epoch=None):
99 |   assert os.path.exists(checkpoint)
100 |   FLAGS.load_from_checkpoint = checkpoint
101 |   file_filter = r'.*\d+\.txt$' if epoch is None else r'.*e\|%d.*' % epoch
102 |   latest_file = ut.get_latest_file(folder=checkpoint, filter=file_filter)
103 |   print(latest_file)
104 |   ut.print_info('Encoding file: %s' % latest_file.split('/')[-1])
105 |   data = np.loadtxt(latest_file)
106 |   fig = plt.figure()
107 |   fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 2)
108 |   entity = EncodingVisualizer(fig, data)
109 |   # fig.tight_layout()
110 |   plt.show()
111 |
112 |
113 | fast = True
114 |
115 | if __name__ == '__main__':
116 |
117 |   cwd = os.getcwd()
118 |   # cwd = '/mnt/code/vd/TensorFlow_DCIGN/tmp/pred.16c3s2_32c3s2_32c3s2_16c3_f80_f8__i_grid.28c.4'
119 |   latest = ut.get_latest_file(cwd, filter=r'.*_suf\.encodings\.npy$')
120 |   print(latest)
121 |   data = np.load(latest).item()
122 |   # print(type(data))
123 |   # i = data.item()
124 |   # print(type(i))
125 |   # print(i.shape)
126 |   # print(data['enc'])
127 |
128 |   # print(data)
129 |   x = data['enc']
130 |
131 |   # print(x)
132 |
133 |   fig = vis.plot_encoding_crosssection(
134 |     x,
135 |     '',
136 |     data['blu'],
137 |     data['rec'],
138 |     interactive=True)
139 |   fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 2)
140 |   # plt.tight_layout()
141 |   plt.show()
142 |
143 |   # path = sys.argv[1] if len(sys.argv) > 1 \
144 |   #   else './tmp/ml__act|sigmoid__bs|30__h|500|10|500__init|na__inp|8pd3__lr|0.00003__opt|AO__seq|03'
145 |   # epoch = int(sys.argv[2]) if len(sys.argv) > 2 else None
146 |   #
147 |   # # path = './tmp/doom_bs__act|sigmoid__bs|30__h|500|12|500__init|na__inp|8pd3__lr|0.0004__opt|AO/'
148 |   #
149 |   # # import os
150 |   # # print('really? ', )
151 |   #
152 |   # if path is None:
153 |   #   ut.print_info('Visualizing latest file from visualization folder')
154 |   #   visualize_latest_from_visualization_folder()
155 |   #   exit(0)
156 |   #
157 |   # is_embedding = '.txt' in path
158 |   # if is_embedding:
159 |   #   ut.print_info('Visualizing encoding file')
160 |   #   visualize_latest_from_visualization_folder(file=path)
161 |   #   exit(0)
162 |   #
163 |   # is_checkpoint = '/tmp' in path
164 |   # if is_checkpoint:
165 |   #   print('so', path)
166 |   #   ut.print_info('Visualizing checkpoint data')
167 |   #   visualize_from_checkpoint(checkpoint=path, epoch=epoch)
168 |   # else:
169 |   #   ut.print_info('Visualizing latest from folder', color=34)
170 |   #   visualize_latest_from_visualization_folder(folder=path)
171 |
--------------------------------------------------------------------------------
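To smoke-test the __main__ path of visualize_latest.py without a checkpoint directory, the encodings dict can be faked; the 'enc'/'blu'/'rec' keys mirror what the script reads from the .npy file, and the shapes are assumptions based on the debug comments elsewhere in the repo:

    import numpy as np
    import matplotlib.pyplot as plt
    import visualization as vis

    data = {'enc': np.random.rand(360, 3),
            'blu': (np.random.rand(360, 80, 160, 3) * 255).astype(np.uint8),
            'rec': (np.random.rand(360, 80, 160, 3) * 255).astype(np.uint8)}
    fig = vis.plot_encoding_crosssection(data['enc'], '', data['blu'], data['rec'],
                                         interactive=True)
    plt.show()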