├── .gitignore
├── .idea
│   ├── TensorFlow_DCIGN.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── Bunch.py
├── DCIGNModel.py
├── IGNModel.py
├── LICENSE
├── Model.py
├── README.md
├── __init__.py
├── activation_functions.py
├── autoencoder.py
├── data_preprocessing
│   ├── generate.command
│   ├── movie.command
│   ├── resize.command
│   ├── resize.sh
│   ├── resize_grey.command
│   ├── resize_no_gun.command
│   ├── tar_all.command
│   └── to32_32.command
├── experiments.py
├── input.py
├── metrics.py
├── model_interpreter.py
├── model_interpreter_test.py
├── network_utils.py
├── tools
│   ├── __init__.py
│   ├── checkpoint_utils.py
│   ├── freeze_graph.py
│   ├── freeze_graph_test.py
│   ├── graph_metrics.py
│   ├── graph_metrics_test.py
│   ├── inspect_checkpoint.py
│   ├── strip_unused.py
│   └── strip_unused_test.py
├── utils.py
├── video_builder.py
├── visualization.py
└── visualize_latest.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 |
92 | # Custom
93 | .idea/
94 | tmp/
95 |
--------------------------------------------------------------------------------
/.idea/TensorFlow_DCIGN.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Bunch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class Bunch(object):
5 | def __init__(self, **kwds):
6 | self.__dict__.update(kwds)
7 |
8 | def __eq__(self, other):
9 | return self.__dict__ == other.__dict__
10 |
11 | def __str__(self):
12 | string = "BUNCH {" + str(self.__dict__)[2:]
13 | string = string.replace("': ", ":")
14 | string = string.replace(", '", ", ")
15 | return string
16 |
17 | def __repr__(self):
18 | return str(self)
19 |
20 | def to_file_name(self, folder=None, ext=None):
21 | res = str(self.__dict__)[2:-1]
22 | res = res.replace("'", "")
23 | res = res.replace(": ", ".")
24 | parts = res.split(', ')
25 | res = '_'.join(sorted(parts))
26 |
27 | if ext is not None:
28 | res = '%s.%s' % (res, ext)
29 | if folder is not None:
30 | res = os.path.join(folder, res)
31 | return res
32 |
33 |
34 | if __name__ == '__main__':
35 | b = Bunch(x=5, y='something', other=9.0)
36 | print(b)
37 | print(b.to_file_name())
38 | print(b.to_file_name('./here', 'txt'))
39 |
40 |
41 |
--------------------------------------------------------------------------------
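Note: the __main__ block above doubles as a usage demo. Tracing the string munging in to_file_name by hand (keys end up sorted, so dict ordering does not matter), the demo should print roughly:

    b = Bunch(x=5, y='something', other=9.0)
    b.to_file_name()                 # -> 'other.9.0_x.5_y.something'
    b.to_file_name('./here', 'txt')  # -> './here/other.9.0_x.5_y.something.txt'

This is how the project turns a set of hyperparameters into a canonical, order-independent file name.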
/DCIGNModel.py:
--------------------------------------------------------------------------------
1 | """MNIST Autoencoder. """
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from six.moves import xrange # pylint: disable=redefined-builtin
8 | import tensorflow as tf
9 | import json, os, re, math, sys
10 | import numpy as np
11 | import utils as ut
12 | import input as inp
13 | import tools.checkpoint_utils as ch_utils
14 | import activation_functions as act
15 | import visualization as vis
16 | import prettytensor as pt
17 | import prettytensor.bookkeeper as bookkeeper
18 | import deconv
19 | from tensorflow.python.ops import gradients
20 | from prettytensor.tutorial import data_utils
21 | import IGNModel
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 | DEV = False
26 |
27 |
28 | class DCIGNModel(IGNModel.IGNModel):
29 | model_id = 'dcign'
30 |
31 | def _build_encoder(self):
32 | """Construct encoder network: placeholders, operations, optimizer"""
33 | self._input = tf.placeholder(tf.float32, self._batch_shape, name='input')
34 | self._encoding = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow), name='encoding')
35 |
36 |     # NOTE: dead code -- this fully-connected encoder is immediately overwritten below
37 |     # self._encode = (pt.wrap(self._input).flatten()
38 |     #                 .fully_connected(self.layer_encoder, name='enc_hidden')
39 |     #                 .fully_connected(self.layer_narrow, name='narrow'))
40 |
41 | self._encode = pt.wrap(self._input)
42 | self._encode = self._encode.conv2d(5, 32, stride=2)
43 | print(self._encode.get_shape())
44 | self._encode = self._encode.conv2d(5, 64, stride=2)
45 | print(self._encode.get_shape())
46 | self._encode = self._encode.conv2d(5, 128, stride=2)
47 | print(self._encode.get_shape())
48 | self._encode = (self._encode.dropout(0.9).
49 | flatten().
50 | fully_connected(self.layer_narrow, activation_fn=None))
51 |
52 | # variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.encoder_scope)
53 | self._encoder_loss = self._encode.l1_regression(pt.wrap(self._encoding))
54 | ut.print_info('new learning rate: %.8f (%f)' % (FLAGS.learning_rate/FLAGS.batch_size, FLAGS.learning_rate))
55 | self._opt_encoder = self._optimizer(learning_rate=FLAGS.learning_rate/FLAGS.batch_size)
56 | self._train_encoder = self._opt_encoder.minimize(self._encoder_loss)
57 |
58 | def _build_decoder(self, weight_init=tf.truncated_normal):
59 | """Construct decoder network: placeholders, operations, optimizer,
60 | extract gradient back-prop for encoding layer"""
61 | self._clamped = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow))
62 | self._reconstruction = tf.placeholder(tf.float32, self._batch_shape)
63 |
64 | clamped_init = np.zeros((FLAGS.batch_size, self.layer_narrow), dtype=np.float32)
65 | self._clamped_variable = tf.Variable(clamped_init, name='clamped')
66 | self._assign_clamped = tf.assign(self._clamped_variable, self._clamped)
67 |
68 | self._decode = pt.wrap(self._clamped_variable)
69 | # self._decode = self._decode.reshape([FLAGS.batch_size, 1, 1, self.layer_narrow])
70 | print(self._decode.get_shape())
71 | self._decode = self._decode.fully_connected(7200)
72 | self._decode = self._decode.reshape([FLAGS.batch_size, 1, 1, 7200])
73 | self._decode = self._decode.deconv2d((10, 20), 128, edges='VALID')
74 | print(self._decode.get_shape())
75 | self._decode = self._decode.deconv2d(5, 64, stride=2)
76 | print(self._decode.get_shape())
77 | self._decode = self._decode.deconv2d(5, 32, stride=2)
78 | print(self._decode.get_shape())
79 | self._decode = self._decode.deconv2d(5, 3, stride=2, activation_fn=tf.nn.sigmoid)
80 | print(self._decode.get_shape())
81 |
82 | # variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.decoder_scope)
83 | self._decoder_loss = self._decode.l2_regression(pt.wrap(self._reconstruction))
84 | self._opt_decoder = self._optimizer(learning_rate=FLAGS.learning_rate/FLAGS.batch_size)
85 | self._train_decoder = self._opt_decoder.minimize(self._decoder_loss)
86 |
87 | self._clamped_grad, = tf.gradients(self._decoder_loss, [self._clamped_variable])
88 |
89 |
90 | def parse_params():
91 | params = {}
92 | for i, param in enumerate(sys.argv):
93 | if '-' in param:
94 | params[param[1:]] = sys.argv[i+1]
95 | print(params)
96 | return params
97 |
98 |
99 | if __name__ == '__main__':
100 | epochs = 500
101 | import sys
102 |
103 | FLAGS.save_every = 5
104 | FLAGS.save_encodings_every = 2
105 |
106 |
107 | model = DCIGNModel()
108 | args = dict([arg.split('=', maxsplit=1) for arg in sys.argv[1:]])
109 | if len(args) == 0:
110 | global DEV
111 |     DEV = True
112 | print('DEVELOPMENT MODE ON')
113 | print(args)
114 | if 'epochs' in args:
115 | epochs = int(args['epochs'])
116 | ut.print_info('epochs: %d' % epochs, color=36)
117 | if 'sigma' in args:
118 | FLAGS.sigma = int(args['sigma'])
119 | if 'suffix' in args:
120 | FLAGS.suffix = args['suffix']
121 | if 'input' in args:
122 | parts = FLAGS.input_path.split('/')
123 | parts[-3] = args['input']
124 | FLAGS.input_path = '/'.join(parts)
125 | ut.print_info('input %s' % FLAGS.input_path, color=36)
126 | if 'h' in args:
127 | layers = list(map(int, args['h'].split('/')))
128 | ut.print_info('layers %s' % str(layers), color=36)
129 | model.set_layer_sizes(layers)
130 | if 'divider' in args:
131 | FLAGS.drag_divider = float(args['divider'])
132 | if 'lr' in args:
133 | FLAGS.learning_rate = float(args['lr'])
134 |
135 | model.train(epochs)
136 |
--------------------------------------------------------------------------------
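A quick sanity check on the hard-coded decoder above (assuming PrettyTensor's default SAME padding for the stride-2 deconvolutions, so each doubles the spatial size, and that a VALID transposed convolution from a 1x1 input yields a kernel-sized map): the 1x1x7200 tensor grows as

    1x1 -> deconv (10,20) VALID -> 10x20 -> 20x40 -> 40x80 -> 80x160x3

so the decoder targets 80x160 RGB frames, and the encoder's three stride-2 convolutions reduce such frames back to 10x20 feature maps before the narrow layer.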
/IGNModel.py:
--------------------------------------------------------------------------------
1 | """MNIST Autoencoder. """
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from six.moves import xrange # pylint: disable=redefined-builtin
8 | import tensorflow as tf
9 | import json, os, re, math
10 | import numpy as np
11 | import utils as ut
12 | import input as inp
13 | import tools.checkpoint_utils as ch_utils
14 | import activation_functions as act
15 | import visualization as vis
16 | import prettytensor as pt
17 | import Model as m
18 |
19 | tf.app.flags.DEFINE_float('gradient_proportion', 5.0, 'Proportion of gradient mixture RECO/DRAG')
20 | tf.app.flags.DEFINE_integer('sequence_length', 25, 'length of the single-variable-variation sequences')
21 |
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | DEV = False
25 |
26 |
27 | def _clamp(encoding, filter):
28 | filter_neg = np.ones(len(filter), dtype=filter.dtype) - filter
29 | # print('\nf_n', filter_neg)
30 | avg = encoding.mean(axis=0)*filter_neg
31 | # print('avg', avg, encoding[0])
32 | # print('avg', encoding.mean(axis=0), encoding[-1], encoding[-1]-avg)
33 | grad = encoding*filter_neg - avg
34 | encoding = encoding * filter + avg
35 | # print('enc', encoding[0], encoding[1])
36 | # print(np.hstack((encoding, grad)))
37 | # print('vae', grad[0], grad[1])
38 | return encoding, grad
39 |
40 |
41 | def _declamp_grad(vae_grad, reco_grad, filter):
42 | # print('vae, reco', np.abs(vae_grad).mean(), np.abs((reco_grad*filter)).mean())
43 | res = vae_grad/FLAGS.gradient_proportion + reco_grad*filter
44 | # res = vae_grad + reco_grad*filter
45 | #print('\nvae: %s\nrec: %s\nres %s' % (ut.print_float_list(vae_grad[1]),
46 | # ut.print_float_list(reco_grad[1]),
47 | # ut.print_float_list(res[0])))
48 | return res
49 |
50 |
51 | class IGNModel(m.Model):
52 | model_id = 'ign'
53 | decoder_scope = 'dec'
54 | encoder_scope = 'enc'
55 |
56 | layer_narrow = 2
57 | layer_encoder = 40
58 | layer_decoder = 40
59 |
60 | _image_shape = None
61 | _batch_shape = None
62 |
63 | # placeholders
64 | _input = None
65 | _encoding = None
66 |
67 | _clamped = None
68 | _reconstruction = None
69 | _clamped_grad = None
70 |
71 | # variables
72 | _clamped_variable = None
73 |
74 | # operations
75 | _encode = None
76 | _encoder_loss = None
77 | _opt_encoder = None
78 | _train_encoder = None
79 |
80 | _decode = None
81 | _decoder_loss = None
82 | _opt_decoder = None
83 | _train_decoder = None
84 |
85 | _step = None
86 | _current_step = None
87 | _visualize_op = None
88 |
89 | def __init__(self,
90 | weight_init=None,
91 | activation=act.sigmoid,
92 | optimizer=tf.train.AdamOptimizer):
93 | super(IGNModel, self).__init__()
94 | FLAGS.batch_size = FLAGS.sequence_length
95 | self._weight_init = weight_init
96 | self._activation = activation
97 | self._optimizer = optimizer
98 | if FLAGS.load_from_checkpoint:
99 | self.load_meta(FLAGS.load_from_checkpoint)
100 |
101 | def get_layer_info(self):
102 | return [self.layer_encoder, self.layer_narrow, self.layer_decoder]
103 |
104 | def get_meta(self, meta=None):
105 | meta = super(IGNModel, self).get_meta(meta=meta)
106 | meta['div'] = FLAGS.gradient_proportion
107 | return meta
108 |
109 | def load_meta(self, save_path):
110 | meta = super(IGNModel, self).load_meta(save_path)
111 | self._weight_init = meta['init']
112 |     self._optimizer = tf.train.AdamOptimizer \
113 |       if 'Adam' in meta['opt'] \
114 |       else tf.train.AdadeltaOptimizer
115 | self._activation = act.sigmoid
116 | self.layer_encoder = meta['h'][0]
117 | self.layer_narrow = meta['h'][1]
118 | self.layer_decoder = meta['h'][2]
119 | FLAGS.gradient_proportion = float(meta['div'])
120 | ut.configure_folders(FLAGS, self.get_meta())
121 | return meta
122 |
123 | # MODEL
124 |
125 | def build_model(self):
126 | tf.reset_default_graph()
127 | self._batch_shape = inp.get_batch_shape(FLAGS.batch_size, FLAGS.input_path)
128 | self._current_step = tf.Variable(0, trainable=False, name='global_step')
129 | self._step = tf.assign(self._current_step, self._current_step + 1)
130 | with pt.defaults_scope(activation_fn=self._activation.func):
131 | with pt.defaults_scope(phase=pt.Phase.train):
132 | with tf.variable_scope(self.encoder_scope):
133 | self._build_encoder()
134 | with tf.variable_scope(self.decoder_scope):
135 | self._build_decoder()
136 |
137 | def _build_encoder(self):
138 | """Construct encoder network: placeholders, operations, optimizer"""
139 | self._input = tf.placeholder(tf.float32, self._batch_shape, name='input')
140 | self._encoding = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow), name='encoding')
141 |
142 | self._encode = (pt.wrap(self._input)
143 | .flatten()
144 | .fully_connected(self.layer_encoder, name='enc_hidden')
145 | .fully_connected(self.layer_narrow, name='narrow'))
146 |
147 | # variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.encoder_scope)
148 | self._encoder_loss = self._encode.l1_regression(pt.wrap(self._encoding))
149 | ut.print_info('new learning rate: %.8f (%f)' % (FLAGS.learning_rate/FLAGS.batch_size, FLAGS.learning_rate))
150 | self._opt_encoder = self._optimizer(learning_rate=FLAGS.learning_rate/FLAGS.batch_size)
151 | self._train_encoder = self._opt_encoder.minimize(self._encoder_loss)
152 |
153 | def _build_decoder(self, weight_init=tf.truncated_normal):
154 | """Construct decoder network: placeholders, operations, optimizer,
155 | extract gradient back-prop for encoding layer"""
156 | self._clamped = tf.placeholder(tf.float32, (FLAGS.batch_size, self.layer_narrow))
157 | self._reconstruction = tf.placeholder(tf.float32, self._batch_shape)
158 |
159 | clamped_init = np.zeros((FLAGS.batch_size, self.layer_narrow), dtype=np.float32)
160 | self._clamped_variable = tf.Variable(clamped_init, name='clamped')
161 | self._assign_clamped = tf.assign(self._clamped_variable, self._clamped)
162 |
163 | # http://stackoverflow.com/questions/40194389/how-to-propagate-gradient-into-a-variable-after-assign-operation
164 | self._decode = (
165 | pt.wrap(self._clamped_variable)
166 | .fully_connected(self.layer_decoder, name='decoder_1')
167 | .fully_connected(np.prod(self._image_shape), init=weight_init, name='output')
168 | .reshape(self._batch_shape))
169 |
170 | # variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.decoder_scope)
171 | self._decoder_loss = self._build_reco_loss(self._reconstruction)
172 | self._opt_decoder = self._optimizer(learning_rate=FLAGS.learning_rate)
173 | self._train_decoder = self._opt_decoder.minimize(self._decoder_loss)
174 |
175 | self._clamped_grad, = tf.gradients(self._decoder_loss, [self._clamped_variable])
176 |
177 | # DATA
178 |
179 | def fetch_datasets(self, activation_func_bounds):
180 | original_data, filters = inp.get_images(FLAGS.input_path)
181 | assert len(filters) == len(original_data)
182 | original_data, filters = self.bloody_hack_filterbatches(original_data, filters)
183 | ut.print_info('shapes. data, filters: %s' % str((original_data.shape, filters.shape)))
184 |
185 | original_data = inp.rescale_ds(original_data, activation_func_bounds.min, activation_func_bounds.max)
186 | self._image_shape = inp.get_image_shape(FLAGS.input_path)
187 |
188 | if DEV:
189 | original_data = original_data[:300]
190 |
191 | self.epoch_size = math.ceil(len(original_data) / FLAGS.batch_size)
192 | self.test_size = math.ceil(len(original_data) / FLAGS.batch_size)
193 | return original_data, filters
194 |
195 | def bloody_hack_filterbatches(self, original_data, filters):
196 | print(filters)
197 | survivers = np.zeros(len(filters), dtype=np.uint8)
198 | j, prev = 0, None
199 | for _, f in enumerate(filters):
200 | if prev is None or prev[0] == f[0] and prev[1] == f[1]:
201 | j += 1
202 | else:
203 | k = j // FLAGS.batch_size
204 | for i in range(k):
205 | start = _ - j + math.ceil(j / k * i)
206 | survivers[start:start + FLAGS.batch_size] += 1
207 | # print(j, survivers[_-j:_])
208 | j = 0
209 | prev = f
210 | original_data = np.asarray([x for i, x in enumerate(original_data) if survivers[i] > 0])
211 | filters = np.asarray([x for i, x in enumerate(filters) if survivers[i] > 0])
212 | return original_data, filters
213 |
214 | def _get_epoch_dataset(self):
215 | ds, filters = self._get_blurred_dataset(), self._filters
216 | # permute
217 | (train_set, filters), permutation = inp.permute_data_in_series((ds, filters), FLAGS.batch_size, allow_shift=False)
218 | # construct feed
219 | feed = pt.train.feed_numpy(FLAGS.batch_size, train_set, filters)
220 | return feed, permutation
221 |
222 | # TRAIN
223 |
224 | def train(self, epochs_to_train=5):
225 | meta = self.get_meta()
226 | ut.print_time('train started: \n%s' % ut.to_file_name(meta))
227 | # return meta, np.random.randn(epochs_to_train)
228 | ut.configure_folders(FLAGS, meta)
229 |
230 | self._dataset, self._filters = self.fetch_datasets(self._activation)
231 | self.build_model()
232 | self._register_training_start()
233 |
234 | with tf.Session() as sess:
235 | sess.run(tf.initialize_all_variables())
236 | self._saver = tf.train.Saver()
237 |
238 | if FLAGS.load_state and os.path.exists(self.get_checkpoint_path()):
239 | self._saver.restore(sess, self.get_checkpoint_path())
240 | ut.print_info('Restored requested. Previous epoch: %d' % self.get_past_epochs(), color=31)
241 |
242 | # MAIN LOOP
243 | for current_epoch in xrange(epochs_to_train):
244 |
245 | feed, permutation = self._get_epoch_dataset()
246 | for _, batch in enumerate(feed):
247 | filter = batch[1][0]
248 | assert batch[1][0,0] == batch[1][-1,0]
249 | encoding, = sess.run([self._encode], feed_dict={self._input: batch[0]}) # 1.1 encode forward
250 | clamped_enc, vae_grad = _clamp(encoding, filter) # 1.2 # clamp
251 |
252 | sess.run(self._assign_clamped, feed_dict={self._clamped:clamped_enc})
253 | reconstruction, loss, clamped_gradient, _ = sess.run( # 2.1 decode forward+backward
254 | [self._decode, self._decoder_loss, self._clamped_grad, self._train_decoder],
255 | feed_dict={self._clamped: clamped_enc, self._reconstruction: batch[0]})
256 |
257 | declamped_grad = _declamp_grad(vae_grad, clamped_gradient, filter) # 2.2 prepare gradient
258 | _, step = sess.run( # 3.0 encode backward path
259 | [self._train_encoder, self._step],
260 | feed_dict={self._input: batch[0], self._encoding: encoding-declamped_grad}) # Profit
261 |
262 |         self._register_batch(loss, batch=batch, encoding=encoding, reconstruction=reconstruction)
263 | self._register_epoch(current_epoch, epochs_to_train, permutation, sess)
264 | self._writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
265 | meta = self._register_training()
266 | return meta, self._stats['epoch_accuracy']
267 |
268 |
269 |
270 | if __name__ == '__main__':
271 | # FLAGS.load_from_checkpoint = './tmp/doom_bs__act|sigmoid__bs|20__h|500|5|500__init|na__inp|cbd4__lr|0.0004__opt|AO'
272 | epochs = 300
273 | import sys
274 |
275 | model = IGNModel()
276 | args = dict([arg.split('=', maxsplit=1) for arg in sys.argv[1:]])
277 | if len(args) == 0:
278 | global DEV
279 |     DEV = True
280 | print('DEVELOPMENT MODE ON')
281 | print(args)
282 | if 'epochs' in args:
283 | epochs = int(args['epochs'])
284 | ut.print_info('epochs: %d' % epochs, color=36)
285 | if 'sigma' in args:
286 | FLAGS.sigma = int(args['sigma'])
287 | if 'suffix' in args:
288 | FLAGS.suffix = args['suffix']
289 | if 'input' in args:
290 | parts = FLAGS.input_path.split('/')
291 | parts[-3] = args['input']
292 | FLAGS.input_path = '/'.join(parts)
293 | ut.print_info('input %s' % FLAGS.input_path, color=36)
294 | if 'h' in args:
295 | layers = list(map(int, args['h'].split('/')))
296 | ut.print_info('layers %s' % str(layers), color=36)
297 | model.set_layer_sizes(layers)
298 | if 'divider' in args:
299 | FLAGS.drag_divider = float(args['divider'])
300 | if 'lr' in args:
301 | FLAGS.learning_rate = float(args['lr'])
302 |
303 |
304 | model.train(epochs)
305 |
--------------------------------------------------------------------------------
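To make the clamping steps (1.2 and 2.2) in the training loop above concrete, here is a minimal NumPy sketch of what _clamp does to a toy batch; the values and shapes are illustrative only:

    import numpy as np

    encoding = np.array([[0.2, 0.9],
                         [0.4, 0.1],
                         [0.6, 0.5]], dtype=np.float32)
    filter = np.array([1., 0.], dtype=np.float32)   # this batch varies only along latent dim 0

    filter_neg = 1 - filter                   # [0, 1]: marks the clamped dimensions
    avg = encoding.mean(axis=0) * filter_neg  # [0, 0.5]: batch mean of the clamped dims
    grad = encoding * filter_neg - avg        # per-row deviation from that mean
    clamped = encoding * filter + avg         # dim 1 pinned to 0.5 for every row

The decoder only ever sees `clamped`; _declamp_grad then mixes the decoder's reconstruction gradient (kept on the free dimension via `filter`) with grad/FLAGS.gradient_proportion, which nudges the clamped dimensions of the encoder output toward their batch mean.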
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 yselivonchyk
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
1 | """MNIST Autoencoder. """
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import tensorflow as tf
8 | import json
9 | import os
10 | import numpy as np
11 | import utils as ut
12 | import input as inp
13 | import tools.checkpoint_utils as ch_utils
14 | import visualization as vis
15 | import matplotlib.pyplot as plt
16 | import time
17 |
18 | tf.app.flags.DEFINE_string('suffix', 'run', 'Suffix to use to distinguish models by purpose')
19 | tf.app.flags.DEFINE_string('input_path', '../data/tmp/grid03.14.c.tar.gz', 'input folder')
20 | tf.app.flags.DEFINE_string('test_path', '../data/tmp/grid03.14.c.tar.gz', 'test set folder')
21 | tf.app.flags.DEFINE_float('test_max', 10000, 'max number of examples in the test set')
22 | tf.app.flags.DEFINE_string('save_path', './tmp/checkpoint', 'Where to save the model checkpoints.')
23 | tf.app.flags.DEFINE_string('logdir', '', 'where to save logs.')
24 | tf.app.flags.DEFINE_string('load_from_checkpoint', None, 'Load model state from particular checkpoint')
25 |
26 | tf.app.flags.DEFINE_integer('max_epochs', 50, 'Train for at most this number of epochs')
27 | tf.app.flags.DEFINE_integer('epoch_size', 100, 'Number of batches per epoch')
28 | tf.app.flags.DEFINE_integer('test_size', 0, 'Number of test batches per epoch')
29 | tf.app.flags.DEFINE_integer('save_every', 250, 'Save model state every INT epochs')
30 | tf.app.flags.DEFINE_integer('save_encodings_every', 5, 'Save encoding and visualizations every')
31 | tf.app.flags.DEFINE_boolean('load_state', True, 'Load state if possible ')
32 |
33 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
34 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
35 |
36 | tf.app.flags.DEFINE_float('dropout', 0.0, 'Dropout probability of pre-narrow units')
37 |
38 | tf.app.flags.DEFINE_float('blur', 5.0, 'Max sigma value for Gaussian blur applied to training set')
39 | tf.app.flags.DEFINE_integer('blur_decrease', 50000, 'Decrease image blur every X steps')
40 |
41 | tf.app.flags.DEFINE_boolean('dev', False, 'Indicate that model is in the development mode')
42 |
43 | FLAGS = tf.app.flags.FLAGS
44 | slim = tf.contrib.slim
45 |
46 | DEV = False
47 |
48 |
49 | def is_stopping_point(current_epoch, epochs_to_train, stop_every=None, stop_x_times=None,
50 | stop_on_last=True):
51 | if stop_on_last and current_epoch + 1 == epochs_to_train:
52 | return True
53 | if stop_x_times is not None:
54 |     return current_epoch % np.ceil(epochs_to_train / float(stop_x_times)) == 0
55 | if stop_every is not None:
56 | return (current_epoch + 1) % stop_every == 0
57 |
58 |
59 | def get_variable(name):
60 | assert FLAGS.load_from_checkpoint
61 | var = ch_utils.load_variable(tf.train.latest_checkpoint(FLAGS.load_from_checkpoint), name)
62 | return var
63 |
64 |
65 | def get_every_dataset():
66 | all_data = [x[0] for x in os.walk('../data/tmp_grey/') if 'img' in x[0]]
67 | print(all_data)
68 | return all_data
69 |
70 |
71 | class Model:
72 | model_id = 'base'
73 | dataset = None
74 | test_set = None
75 |
76 | _writer, _saver = None, None
77 | _dataset, _filters = None, None
78 |
79 | def get_layer_info(self):
80 | return [self.layer_encoder, self.layer_narrow, self.layer_decoder]
81 |
82 | # MODEL
83 |
84 | def build_model(self):
85 | pass
86 |
87 | def _build_encoder(self):
88 | pass
89 |
90 | def _build_decoder(self, weight_init=tf.truncated_normal):
91 | pass
92 |
93 | def _build_reco_loss(self, output_placeholder):
94 | error = self._decode - slim.flatten(output_placeholder)
95 | return tf.nn.l2_loss(error, name='reco_loss')
96 |
97 | def train(self, epochs_to_train=5):
98 | pass
99 |
100 | # META
101 |
102 | def get_meta(self, meta=None):
103 | meta = meta if meta else {}
104 |
105 | meta['postf'] = self.model_id
106 | meta['a'] = 's'
107 | meta['lr'] = FLAGS.learning_rate
108 | meta['init'] = self._weight_init
109 | meta['bs'] = FLAGS.batch_size
110 | meta['h'] = self.get_layer_info()
111 | meta['opt'] = self._optimizer
112 | meta['inp'] = inp.get_input_name(FLAGS.input_path)
113 | meta['do'] = FLAGS.dropout
114 | return meta
115 |
116 | def save_meta(self, meta=None):
117 | if meta is None:
118 | meta = self.get_meta()
119 |
120 | ut.configure_folders(FLAGS, meta)
121 | meta['a'] = 's'
122 | meta['opt'] = str(meta['opt']).split('.')[-1][:-2]
123 | meta['input_path'] = FLAGS.input_path
124 | path = os.path.join(FLAGS.save_path, 'meta.txt')
125 | json.dump(meta, open(path, 'w'))
126 |
127 | def load_meta(self, save_path):
128 | path = os.path.join(save_path, 'meta.txt')
129 | meta = json.load(open(path, 'r'))
130 | FLAGS.save_path = save_path
131 | FLAGS.batch_size = meta['bs']
132 | FLAGS.input_path = meta['input_path']
133 | FLAGS.learning_rate = meta['lr']
134 | FLAGS.load_state = True
135 | FLAGS.dropout = float(meta['do'])
136 | return meta
137 |
138 | # DATA
139 |
140 | _blurred_dataset, _last_blur = None, 0
141 |
142 | def _get_blur_sigma(self, step=None):
143 | step = step if step is not None else self._current_step.eval()
144 | calculated_sigma = FLAGS.blur - int(10 * step / FLAGS.blur_decrease) / 10.0
145 | return max(0, calculated_sigma)
146 |
147 | def _get_blurred_dataset(self):
148 | if FLAGS.blur != 0:
149 | current_sigma = self._get_blur_sigma()
150 | if current_sigma != self._last_blur:
151 | self._last_blur = current_sigma
152 | self._blurred_dataset = inp.apply_gaussian(self.dataset, sigma=current_sigma)
153 | return self._blurred_dataset if self._blurred_dataset is not None else self.dataset
154 |
155 | # MISC
156 |
157 | def get_past_epochs(self):
158 | return int(self._current_step.eval() / FLAGS.epoch_size)
159 |
160 | @staticmethod
161 | def get_checkpoint_path():
162 | return os.path.join(FLAGS.save_path, '-9999.chpt')
163 |
164 | # OUTPUTS
165 | @staticmethod
166 | def _get_stats_template():
167 | return {
168 | 'batch': [],
169 | 'input': [],
170 | 'encoding': [],
171 | 'reconstruction': [],
172 | 'total_loss': 0,
173 | 'start': time.time()
174 | }
175 |
176 | _epoch_stats = None
177 | _stats = None
178 |
179 | @ut.timeit
180 | def restore_model(self, session):
181 | self._saver = tf.train.Saver()
182 | latest_checkpoint = tf.train.latest_checkpoint(self.get_checkpoint_path()[:-10])
183 | ut.print_info("latest checkpoint: %s" % latest_checkpoint)
184 | if FLAGS.load_state and latest_checkpoint is not None:
185 | self._saver.restore(session, latest_checkpoint)
186 | ut.print_info('Restored requested. Previous epoch: %d' % self.get_past_epochs(), color=31)
187 |
188 | def _register_training_start(self, sess):
189 | self.summary_writer = tf.summary.FileWriter('/tmp/train', sess.graph)
190 |
191 | self._epoch_stats = self._get_stats_template()
192 | self._stats = {
193 | 'epoch_accuracy': [],
194 | 'epoch_reconstructions': [],
195 | 'permutation': None
196 | }
197 |
198 | if FLAGS.dev:
199 | plt.ion()
200 | plt.show()
201 |
202 | # @ut.timeit
203 | def _register_batch(self, loss, batch=None, encoding=None, reconstruction=None, step=None):
204 | self._epoch_stats['total_loss'] += loss
205 | if FLAGS.dev:
206 | assert batch is not None and reconstruction is not None
207 | original = batch[0][:, 0]
208 | vis.plot_reconstruction(original, reconstruction, interactive=True)
209 |
210 | MAX_IMAGES = 10
211 |
212 | # @ut.timeit
213 | def _register_epoch(self, epoch, total_epochs, elapsed, sess):
214 | if is_stopping_point(epoch, total_epochs, FLAGS.save_every):
215 | self._saver.save(sess, self.get_checkpoint_path())
216 |
217 | accuracy = 100000 * np.sqrt(self._epoch_stats['total_loss'] / np.prod(self._batch_shape) / FLAGS.epoch_size)
218 |
219 | if is_stopping_point(epoch, total_epochs, FLAGS.save_encodings_every):
220 | digest = self.evaluate(sess, take=self.MAX_IMAGES)
221 | data = {
222 | 'enc': np.asarray(digest[0]),
223 | 'rec': np.asarray(digest[1]),
224 | 'blu': np.asarray(digest[2][:self.MAX_IMAGES])
225 | }
226 |
227 | meta = {'suf': 'encodings', 'e': '%06d' % int(self.get_past_epochs()), 'er': int(accuracy)}
228 | projection_file = ut.to_file_name(meta, FLAGS.save_path)
229 | np.save(projection_file, data)
230 | vis.plot_encoding_crosssection(data['enc'], FLAGS.save_path, meta, data['blu'], data['rec'])
231 |
232 | self._stats['epoch_accuracy'].append(accuracy)
233 | self.print_epoch_info(accuracy, epoch, total_epochs, elapsed)
234 | if epoch + 1 != total_epochs:
235 | self._epoch_stats = self._get_stats_template()
236 |
237 | @ut.timeit
238 | def _register_training(self):
239 | best_acc = np.min(self._stats['epoch_accuracy'])
240 | meta = self.get_meta()
241 | meta['acu'] = int(best_acc)
242 | meta['e'] = self.get_past_epochs()
243 | ut.print_time('Best Quality: %f for %s' % (best_acc, ut.to_file_name(meta)))
244 | self.summary_writer.close()
245 | return meta
246 |
247 | def print_epoch_info(self, accuracy, current_epoch, epochs, elapsed):
248 | epochs_past = self.get_past_epochs() - current_epoch
249 | accuracy_info = '' if accuracy is None else '| accuracy %d' % int(accuracy)
250 | epoch_past_info = '' if epochs_past is None else '+%d' % (epochs_past - 1)
251 | epoch_count = 'Epochs %2d/%d%s' % (current_epoch + 1, epochs, epoch_past_info)
252 | time_info = '%2dms/bt' % (elapsed / FLAGS.epoch_size * 1000)
253 |
254 | info_string = ' '.join([
255 | epoch_count,
256 | accuracy_info,
257 | time_info])
258 |
259 | ut.print_time(info_string, same_line=True)
260 |
--------------------------------------------------------------------------------
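For reference, the blur curriculum in _get_blur_sigma steps sigma down in 0.1 increments. With the default flags above (blur=5.0, blur_decrease=50000) the schedule works out to:

    sigma(step) = max(0, 5.0 - floor(10 * step / 50000) / 10)
    # step       0 -> 5.0
    # step   5,000 -> 4.9
    # step  50,000 -> 4.0
    # step 250,000 -> 0.0  (blur fully annealed)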
/README.md:
--------------------------------------------------------------------------------
1 | # DCIGN_tensorflow
2 | Deep Convolutional Inverse Graphics Network (DCIGN) implementation in TensorFlow
3 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yselivonchyk/TensorFlow_DCIGN/ff8d85f3a7b7ca1e5c3f50ff003a1c09a70067cd/__init__.py
--------------------------------------------------------------------------------
/activation_functions.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | sigmoid = type('Fake', (object,), { "func": tf.nn.sigmoid, "min": 0, 'max': 1})
4 | tanh = type('Fake', (object,), { "func": tf.nn.tanh, "min": -1, 'max': 1})
5 | relu = type('Fake', (object,), { "func": tf.nn.relu, "min": 0, 'max': 1})
6 |
--------------------------------------------------------------------------------
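The type('Fake', ...) one-liners above are lightweight namespaces bundling an activation op with its output range; callers use the class objects directly, never instantiating them. Roughly how the models consume them (mirroring IGNModel.fetch_datasets and build_model in this repo):

    import activation_functions as act

    a = act.sigmoid
    y = a.func(x)                              # the TF op itself
    data = inp.rescale_ds(data, a.min, a.max)  # rescale inputs into the activation's range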
/autoencoder.py:
--------------------------------------------------------------------------------
1 | """MNIST Autoencoder. """
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
9 | import numpy as np
10 | import tensorflow as tf
11 | import utils as ut
12 | import input as inp
13 | import visualization as vis
14 | import matplotlib.pyplot as plt
15 | import time
16 | import sys
17 | import getch
18 | import model_interpreter as interpreter
19 | import network_utils as nut
20 | import math
21 | from tensorflow.contrib.tensorboard.plugins import projector
22 | from Bunch import Bunch
23 |
24 |
25 | tf.app.flags.DEFINE_string('input_path', '../data/tmp/grid03.14.c.tar.gz', 'input folder')
26 | tf.app.flags.DEFINE_string('input_name', '', 'input dataset name (derived from input_path at runtime)')
27 | tf.app.flags.DEFINE_string('test_path', '', 'test set folder')
28 | tf.app.flags.DEFINE_string('net', 'f100-f3', 'model configuration')
29 | tf.app.flags.DEFINE_string('model', 'noise', 'Type of the model to use: Autoencoder (ae), '
30 |                                            'WhatWhereAe (ww), U-netAe (u), Predictive (pred), Denoising (noise)')
31 | tf.app.flags.DEFINE_string('postfix', '', 'Postfix for the training folder')
32 |
33 | tf.app.flags.DEFINE_float('alpha', 10, 'Predictive reconstruction loss weight')
34 | tf.app.flags.DEFINE_float('beta', 0.0005, 'Reconstruction from noisy data loss weight')
35 | tf.app.flags.DEFINE_float('epsilon', 0.000001,
36 |                           'Diameter of the epsilon sphere relative to the distance to a neighbour. <= 0.5')
37 | tf.app.flags.DEFINE_float('gamma', 50., 'Loss weight for large distances')
38 | tf.app.flags.DEFINE_float('distance', 0.01, 'Maximum allowed interpoint distance')
39 | tf.app.flags.DEFINE_float('delta', 1., 'Loss weight for stacked objective')
40 |
41 | tf.app.flags.DEFINE_string('comment', '', 'Comment to leave by the model')
42 |
43 | tf.app.flags.DEFINE_float('test_max', 10000, 'max number of examples in the test set')
44 |
45 | tf.app.flags.DEFINE_integer('max_epochs', 0, 'Train for at most this number of epochs')
46 | tf.app.flags.DEFINE_integer('save_every', 250, 'Save model state every INT epochs')
47 | tf.app.flags.DEFINE_integer('eval_every', 25, 'Save encoding and visualizations every')
48 | tf.app.flags.DEFINE_integer('visualiza_max', 10, 'Max pairs to show on visualization')
49 | tf.app.flags.DEFINE_boolean('load_state', True, 'Load state if possible ')
50 | tf.app.flags.DEFINE_boolean('kill_depth', False, 'Ignore depth information')
51 | tf.app.flags.DEFINE_boolean('dev', False, 'Indicate development mode')
52 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
53 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Create visualization of ')
54 |
55 | tf.app.flags.DEFINE_float('blur', 5.0, 'Max sigma value for Gaussian blur applied to training set')
56 | tf.app.flags.DEFINE_boolean('new_blur', False, 'Use data augmentation as blur info')
57 | tf.app.flags.DEFINE_integer('blur_decrease', 10000, 'Decrease image blur every X steps')
58 |
59 | FLAGS = tf.app.flags.FLAGS
60 | slim = tf.contrib.slim
61 |
62 |
63 | AUTOENCODER = 'ae'
64 | PREDICTIVE = 'pred'
65 | DENOISING = 'noise'
66 |
67 | CHECKPOINT_NAME = '-9999.chpt'
68 | EMB_SUFFIX = '_embedding'
69 |
70 |
71 | def is_stopping_point(current_epoch, epochs_to_train, stop_every=None, stop_x_times=None,
72 | stop_on_last=True):
73 | if stop_on_last and current_epoch + 1 == epochs_to_train:
74 | return True
75 | if stop_x_times is not None:
76 | return current_epoch % np.ceil(epochs_to_train / float(stop_x_times)) == 0
77 | if stop_every is not None:
78 | return (current_epoch + 1) % stop_every == 0
79 |
80 |
81 | def _fetch_dataset(path, take=None):
82 | dataset = inp.read_ds_zip(path) # read
83 | take = len(dataset) if take is None else take
84 | dataset = dataset[:take]
85 | # print(dataset.dtype, dataset.shape, np.min(dataset), np.max(dataset))
86 | # dataset = inp.rescale_ds(dataset, 0, 1)
87 | if FLAGS.kill_depth:
88 | dataset[..., -1] = 0
89 | ut.print_info('DS fetch: %8d (%s)' % (len(dataset), path))
90 | return dataset
91 |
92 |
93 | def l2(x):
94 | l = x.get_shape().as_list()[0]
95 | return tf.reshape(tf.sqrt(tf.reduce_sum(x ** 2, axis=1)), (l, 1))
96 |
97 |
98 | def get_stats_template():
99 | return Bunch(
100 | batch=[],
101 | input=[],
102 | encoding=[],
103 | reconstruction=[],
104 | total_loss=0.,
105 | start=time.time())
106 |
107 |
108 | def guard_nan(x):
109 | return x if not math.isnan(x) else -1.
110 |
111 |
112 | def _blur_expand(input):
113 | k_size = 9
114 | kernels = [2, 4, 6]
115 | channels = [input] + [nut.blur_gaussian(input, k, k_size)[0] for k in kernels]
116 | res = tf.concat(channels, axis=3)
117 | return res
118 |
119 |
120 | class Autoencoder:
121 | train_set, test_set = None, None
122 | permutation = None
123 | batch_shape = None
124 | epoch_size = None
125 |
126 | input, target = None, None # AE placeholders
127 | encode, decode = None, None # AE operations
128 | model = None # interpreted model
129 |
130 | encoding = None # AE predictive evaluation placeholder
131 | eval_decode, eval_loss = None, None # AE evaluation
132 |
133 | inputs, targets = None, None # Noise/Predictive placeholders
134 | raw_inputs, raw_targets = None, None # inputs in network-friendly representation
135 | models = None # Noise/Predictive interpreted models
136 |
137 | optimizer, _train = None, None
138 | loss_ae, loss_reco, loss_pred, loss_dn = None, None, None, None # Objectives
139 | loss_total = None
140 | losses = []
141 |
142 | step = None # operation
143 | step_var = None # variable
144 |
145 | vis_summary, vis_placeholder = None, None
146 | image_summaries = None
147 | visualization_batch_perm = None
148 |
149 |
150 | def __init__(self, optimizer=tf.train.AdamOptimizer, need_forlders=True):
151 | self.optimizer_constructor = optimizer
152 | FLAGS.input_name = inp.get_input_name(FLAGS.input_path)
153 | if need_forlders:
154 | ut.configure_folders(FLAGS)
155 | ut.print_flags(FLAGS)
156 |
157 | # MISC
158 |
159 |
160 | def get_past_epochs(self):
161 | return int(self.step.eval() / self.epoch_size)
162 |
163 | @staticmethod
164 | def get_checkpoint_path():
165 | # print(os.path.join(FLAGS.save_path, CHECKPOINT_NAME), len(CHECKPOINT_NAME))
166 | return os.path.join(FLAGS.save_path, CHECKPOINT_NAME)
167 |
168 | def get_latest_checkpoint(self):
169 | return tf.train.latest_checkpoint(
170 | self.get_checkpoint_path()[:-len(EMB_SUFFIX)],
171 | latest_filename='checkpoint'
172 | )
173 |
174 |
175 | # DATA
176 |
177 |
178 | def fetch_datasets(self):
179 | if FLAGS.max_epochs == 0:
180 | FLAGS.input_path = FLAGS.test_path
181 | self.train_set = _fetch_dataset(FLAGS.input_path)
182 | self.epoch_size = int(self.train_set.shape[0] / FLAGS.batch_size)
183 | self.batch_shape = [FLAGS.batch_size] + list(self.train_set.shape[1:])
184 |
185 | reuse_train = FLAGS.test_path == FLAGS.input_path or FLAGS.test_path == ''
186 | self.test_set = self.train_set.copy() if reuse_train else _fetch_dataset(FLAGS.test_path)
187 | take_test = int(FLAGS.test_max) if FLAGS.test_max > 1 else int(FLAGS.test_max * len(self.test_set))
188 | ut.print_info('take %d from test' % take_test)
189 | self.test_set = self.test_set[:take_test]
190 |
191 | def _batch_generator(self, x=None, y=None, shuffle=True, batches=None):
192 | """Returns BATCH_SIZE of couples of subsequent images"""
193 | x = x if x is not None else self._get_blurred_dataset()
194 | y = y if y is not None else x
195 | batches = batches if batches is not None else int(np.floor(len(x) / FLAGS.batch_size))
196 | self.permutation = np.arange(len(x))
197 | self.permutation = self.permutation if not shuffle else np.random.permutation(self.permutation)
198 |
199 | for i in range(batches):
200 | batch_indexes = self.permutation[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
201 | # batch = np.stack((dataset[batch_indexes], dataset[batch_indexes + 1], dataset[batch_indexes + 2]), axis=1)
202 | yield x[batch_indexes], y[batch_indexes]
203 |
204 | def _batch_permutation_generator(self, length, start=0, shuffle=True, batches=None):
205 | self.permutation = np.arange(length) + start
206 | self.permutation = self.permutation if not shuffle else np.random.permutation(self.permutation)
207 | for i in range(int(length/FLAGS.batch_size)):
208 | if batches is not None and i >= batches:
209 | break
210 | yield self.permutation[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
211 |
212 | _blurred_dataset, _last_blur = None, 0
213 |
214 | def _get_blur_sigma(self):
215 | calculated_sigma = FLAGS.blur - int(10 * self.step.eval() / FLAGS.blur_decrease) / 10.0
216 | return max(0, calculated_sigma)
217 |
218 | # @ut.timeit
219 | def _get_blurred_dataset(self):
220 | if FLAGS.blur != 0:
221 | current_sigma = self._get_blur_sigma()
222 | if current_sigma != self._last_blur:
223 | # print(self._last_blur, current_sigma)
224 | self._last_blur = current_sigma
225 | self._blurred_dataset = inp.apply_gaussian(self.train_set, sigma=current_sigma)
226 | ut.print_info('blur s:%.1f[%.1f>%.1f]' % (current_sigma, self.train_set[2, 10, 10, 0], self._blurred_dataset[2, 10, 10, 0]))
227 | return self._blurred_dataset if self._blurred_dataset is not None else self.train_set
228 | return self.train_set
229 |
230 |
231 | # TRAIN
232 |
233 |
234 | def build_ae_model(self):
235 | self.input = tf.placeholder(tf.uint8, self.batch_shape, name='input')
236 | self.target = tf.placeholder(tf.uint8, self.batch_shape, name='target')
237 | self.step = tf.Variable(0, trainable=False, name='global_step')
238 | root = self._image_to_tensor(self.input)
239 | target = self._image_to_tensor(self.target)
240 |
241 | model = interpreter.build_autoencoder(root, FLAGS.net)
242 |
243 | self.encode = model.encode
244 |
245 | self.model = model
246 | self.encoding = tf.placeholder(self.encode.dtype, self.encode.get_shape(), name='encoding')
247 | eval_decode = interpreter.build_decoder(self.encoding, model.config, reuse=True)
248 | print(target, eval_decode)
249 | self.eval_loss = interpreter.l2_loss(target, eval_decode, name='predictive_reconstruction')
250 | self.eval_decode = self._tensor_to_image(eval_decode)
251 |
252 | self.loss_ae = interpreter.l2_loss(target, model.decode, name='reconstruction')
253 | self.decode = self._tensor_to_image(model.decode)
254 | self.losses = [self.loss_ae]
255 |
256 | def build_predictive_model(self):
257 |     self.build_ae_model() # builds on top of the AE model (initializes the auxiliary operations it reuses)
258 | self.inputs = tf.placeholder(tf.uint8, [3] + self.batch_shape, name='inputs')
259 | self.targets = tf.placeholder(tf.uint8, [3] + self.batch_shape, name='targets')
260 |
261 | # transform inputs
262 | self.raw_inputs = [self._image_to_tensor(self.inputs[i]) for i in range(3)]
263 | self.raw_targets = [self._image_to_tensor(self.targets[i]) for i in range(3)]
264 |
265 | # build AE objective for triplet
266 | config = self.model.config
267 | models = [interpreter.build_autoencoder(x, config) for x in self.raw_inputs]
268 | reco_losses = [1./3 * interpreter.l2_loss(models[i].decode, self.raw_targets[i]) for i in range(3)] # business as usual
269 | self.models = models
270 |
271 | # build predictive objective
272 | pred_loss_2 = self._prediction_decode(models[1].encode*2 - models[0].encode, self.raw_targets[2], models[2])
273 | pred_loss_0 = self._prediction_decode(models[1].encode*2 - models[2].encode, self.raw_targets[0], models[0])
274 |
275 | # build regularized distance objective
276 | dist_loss1 = self._distance_loss(models[1].encode - models[0].encode)
277 | dist_loss2 = self._distance_loss(models[1].encode - models[2].encode)
278 |
279 | # Stitch it all together and train
280 | self.loss_reco = tf.add_n(reco_losses)
281 | self.loss_pred = pred_loss_0 + pred_loss_2
282 | self.loss_dist = dist_loss1 + dist_loss2
283 | self.losses = [self.loss_reco, self.loss_pred]
284 |
285 | def _distance_loss(self, distances):
286 | error = tf.nn.relu(l2(distances) - FLAGS.distance ** 2)
287 | return tf.reduce_sum(error)
288 |
289 | def _prediction_decode(self, prediction, target, model):
290 | """Predict encoding t3 by encoding (t2 and t1) and expect a good reconstruction"""
291 | predict_decode = interpreter.build_decoder(prediction, self.model.config, reuse=True, masks=model.mask_list)
292 | predict_loss = 1./2 * interpreter.l2_loss(predict_decode, target, alpha=FLAGS.alpha)
293 | self.models += [predict_decode]
294 | return predict_loss * FLAGS.gamma
295 |
296 |
297 | def build_denoising_model(self):
298 | self.build_predictive_model() # builds on top of predictive model. Reuses triplet encoding
299 |
300 | # build denoising objective
301 | models = self.models
302 | self.loss_dn = self._noisy_decode(models[1])
303 | self.losses = [self.loss_reco, self.loss_pred, self.loss_dist, self.loss_dn]
304 |
305 | def _noisy_decode(self, model):
306 | """Distort middle encoding with [<= 1/3*dist(neigbour)] and demand good reconstruction"""
307 | # dist = l2(x1 - x2)
308 | # noise = dist * self.epsilon_sphere_noise()
309 | # tf.stop_gradient(noise)
310 | noise = tf.random_normal(self.model.encode.get_shape().as_list()) * FLAGS.epsilon
311 | noisy_encoding = noise + self.models[1].encode
312 | tf.stop_gradient(noisy_encoding) # or maybe here, who knows
313 | noisy_decode = interpreter.build_decoder(noisy_encoding, model.config, reuse=True, masks=model.mask_list)
314 | loss = interpreter.l2_loss(noisy_decode, self.raw_targets[1], alpha=FLAGS.beta)
315 | self.models += [noisy_decode]
316 | return loss
317 |
318 | def _tensor_to_image(self, net):
319 | with tf.name_scope('to_image'):
320 | if FLAGS.new_blur:
321 | net = net[..., :self.batch_shape[-1]]
322 | net = tf.nn.relu(net)
323 | net = tf.cast(net <= 1, net.dtype) * net * 255
324 | net = tf.cast(net, tf.uint8)
325 | return net
326 |
327 | def _image_to_tensor(self, image):
328 | with tf.name_scope('args_transform'):
329 | net = tf.cast(image, tf.float32) / 255.
330 | if FLAGS.new_blur:
331 | net = _blur_expand(net)
332 | FLAGS.blur = 0.
333 | return net
334 |
335 | def _init_optimizer(self):
336 | self.loss_total = tf.add_n(self.losses, 'loss_total')
337 | self.optimizer = self.optimizer_constructor(learning_rate=FLAGS.learning_rate)
338 | self._train = self.optimizer.minimize(self.loss_total, global_step=self.step)
339 |
340 |
341 | # MAIN
342 |
343 |
344 | def train(self):
345 | self.fetch_datasets()
346 | if FLAGS.model == AUTOENCODER:
347 | self.build_ae_model()
348 | elif FLAGS.model == PREDICTIVE:
349 | self.build_predictive_model()
350 | else:
351 | self.build_denoising_model()
352 | self._init_optimizer()
353 |
354 | with tf.Session() as sess:
355 | sess.run(tf.global_variables_initializer())
356 | self._on_training_start(sess)
357 |
358 | try:
359 | for current_epoch in range(FLAGS.max_epochs):
360 | start = time.time()
361 | full_set_blur = len(self.train_set) < 50000
362 | ds = self._get_blurred_dataset() if full_set_blur else self.train_set
363 | if FLAGS.model == AUTOENCODER:
364 |
365 | # Autoencoder Training
366 | for batch in self._batch_generator():
367 | summs, encoding, reconstruction, loss, _, step = sess.run(
368 | [self.summs_train, self.encode, self.decode, self.loss_ae, self.train_ae, self.step],
369 | feed_dict={self.input: batch[0], self.target: batch[1]}
370 | )
371 | self._on_batch_finish(summs, loss, batch, encoding, reconstruction)
372 |
373 | else:
374 |
375 | # Predictive and Denoising training
376 | for batch_indexes in self._batch_permutation_generator(len(ds)-2):
377 | batch = np.stack((ds[batch_indexes], ds[batch_indexes + 1], ds[batch_indexes + 2]))
378 | if not full_set_blur:
379 | batch = np.stack((
380 | inp.apply_gaussian(ds[batch_indexes], sigma=self._get_blur_sigma()),
381 | inp.apply_gaussian(ds[batch_indexes+1], sigma=self._get_blur_sigma()),
382 | inp.apply_gaussian(ds[batch_indexes+2], sigma=self._get_blur_sigma())
383 | ))
384 |
385 | summs, loss, _ = sess.run(
386 | [self.summs_train, self.loss_total, self._train],
387 | feed_dict={self.inputs: batch, self.targets: batch})
388 | self._on_batch_finish(summs, loss)
389 |
390 | self._on_epoch_finish(current_epoch, start, sess)
391 | self._on_training_finish(sess)
392 | except KeyboardInterrupt:
393 | self._on_training_abort(sess)
394 |
395 |   def inference(self, max=10**6):  # 10**6, not 10^6 (which is XOR and equals 12)
396 | self.fetch_datasets()
397 | self.build_ae_model()
398 |
399 | with tf.Session() as sess:
400 | sess.run(tf.global_variables_initializer())
401 | # nut.print_model_info()
402 | # nut.list_checkpoint_vars(self.get_latest_checkpoint().replace(EMB_SUFFIX, ''))
403 |
404 | self.saver = tf.train.Saver()
405 | self._restore_model(sess)
406 | # nut.print_model_info()
407 |
408 | encoding, decoding = None, None
409 | for i in range(len(self.train_set)):
410 | batch = np.expand_dims(self.train_set[i], axis=0)
411 | enc, dec = sess.run(
412 | [self.encode, self.decode],
413 | feed_dict={self.input: batch}
414 | )
415 |
416 | # print(enc.shape, dec.shape)
417 | encoding = enc if i == 0 else np.vstack((encoding, enc))
418 | decoding = dec if i == 0 else np.vstack((decoding, dec))
419 | print('\r%5d/%d' % (i, len(self.train_set)), end='')
420 | if i >= max:
421 | break
422 | return encoding, decoding
423 |
424 | # @ut.timeit
425 | def evaluate(self, sess, take):
426 | digest = Bunch(encoded=None, reconstructed=None, source=None,
427 | loss=.0, eval_loss=.0, dumb_loss=.0)
428 | blurred = inp.apply_gaussian(self.test_set, self._get_blur_sigma())
429 | # Encode
430 | for i, batch in enumerate(self._batch_generator(blurred, shuffle=False)):
431 | encoding = self.encode.eval(feed_dict={self.input: batch[0]})
432 | digest.encoded = ut.concatenate(digest.encoded, encoding)
433 | # Save encoding for visualization
434 | encoded_no_nan = np.nan_to_num(digest.encoded)
435 | self.embedding_assign.eval(feed_dict={self.embedding_test_ph: encoded_no_nan})
436 | try:
437 | self.embedding_saver.save(sess, self.get_checkpoint_path() + EMB_SUFFIX)
438 | except:
439 | ut.print_info("Unexpected error: %s" % str(sys.exc_info()[0]), color=33)
440 |
441 | # Calculate expected evaluation
442 | expected = digest.encoded[1:-1]*2 - digest.encoded[:-2]
443 | average = 0.5 * (digest.encoded[1:-1] + digest.encoded[:-2])
444 | digest.size = len(expected)
445 | # evaluation summaries
446 | self.summary_writer.add_summary(self.eval_summs.eval(
447 | feed_dict={self.blur_ph: self._get_blur_sigma()}),
448 | global_step=self.get_past_epochs())
449 | # evaluation losses
450 | for p in self._batch_permutation_generator(digest.size, shuffle=False):
451 | digest.loss += self.eval_loss.eval(feed_dict={self.encoding: digest.encoded[p + 2], self.target: blurred[p + 2]})
452 | digest.eval_loss += self.eval_loss.eval(feed_dict={self.encoding: expected[p], self.target: blurred[p + 2]})
453 | digest.dumb_loss += self.loss_ae.eval( feed_dict={self.input: blurred[p], self.target: blurred[p + 2]})
454 |
455 | # for batch in self._batch_generator(blurred, batches=1):
456 | # digest.source = batch[1][:take]
457 | # digest.reconstructed = self.decode.eval(feed_dict={self.input: batch[0]})[:take]
458 |
459 | # Reconstruction visualizations
460 | for p in self._batch_permutation_generator(digest.size, shuffle=True, batches=1):
461 | self.visualization_batch_perm = self.visualization_batch_perm if self.visualization_batch_perm is not None else p
462 | p = self.visualization_batch_perm
463 |       # dead: immediately overwritten below -- digest.source = self.eval_decode.eval(feed_dict={self.encoding: expected[p]})[:take]
464 | digest.source = blurred[(p+2)[:take]]
465 | digest.reconstructed = self.eval_decode.eval(feed_dict={self.encoding: average[p]})[:take]
466 | self._eval_image_summaries(blurred[p], digest.encoded[p], average[p], expected[p])
467 |
468 | digest.dumb_loss = guard_nan(digest.dumb_loss)
469 | digest.eval_loss = guard_nan(digest.eval_loss)
470 | digest.loss = guard_nan(digest.loss)
471 | return digest
472 |
473 | def _eval_image_summaries(self, blurred_batch, actual, average, expected):
474 | """Create Tensorboard summaries with image reconstructions"""
475 | noisy = expected + np.random.randn(*expected.shape) * FLAGS.epsilon
476 |
477 | summary = self.image_summaries['orig'].eval(feed_dict={self.input: blurred_batch})
478 | self.summary_writer.add_summary(summary, global_step=self.get_past_epochs())
479 |
480 | self._eval_image_summary('midd', average)
481 | # self._eval_image_summary('reco', actual)
482 | self._eval_image_summary('pred', expected)
483 | self._eval_image_summary('nois', noisy)
484 |
485 |   def _eval_image_summary(self, name, encoding_batch):
486 |     summary = self.image_summaries[name].eval(feed_dict={self.encoding: encoding_batch})
487 | self.summary_writer.add_summary(summary, global_step=self.get_past_epochs())
488 |
489 | def _add_decoding_summary(self, name, var, collection='train'):
490 | var = var[:FLAGS.visualiza_max]
491 | var = tf.concat(tf.unstack(var), axis=0)
492 | var = tf.expand_dims(var, dim=0)
493 | color_s = tf.summary.image(name, var[..., :3], max_outputs=FLAGS.visualiza_max)
494 | var = tf.expand_dims(var[..., 3], dim=3)
495 | bw_s = tf.summary.image('depth_' + name, var, max_outputs=FLAGS.visualiza_max)
496 | return tf.summary.merge([color_s, bw_s])
497 |
498 |
499 | # TRAINING PROGRESS EVENTS
500 |
501 |
502 | def _on_training_start(self, sess):
503 | # Writers and savers
504 | self.summary_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)
505 | self.saver = tf.train.Saver()
506 | self._build_embedding_saver(sess)
507 | self._restore_model(sess)
508 | # Loss summaries
509 | self._build_summaries()
510 |
511 | self.epoch_stats = get_stats_template()
512 | self.stats = Bunch(
513 | epoch_accuracy=[],
514 | epoch_reconstructions=[],
515 | permutation=None
516 | )
517 | # if FLAGS.dev:
518 | # plt.ion()
519 | # plt.show()
520 |
521 | def _build_summaries(self):
522 | # losses
523 | with tf.name_scope('losses'):
524 | loss_names = ['loss_autoencoder', 'loss_predictive', 'loss_distance', 'loss_denoising']
525 | for i, loss in enumerate(self.losses):
526 | self._add_loss_summary(loss_names[i], loss)
527 | self._add_loss_summary('loss_total', self.loss_total)
528 | self.summs_train = tf.summary.merge_all('train')
529 | # reconstructions
530 | with tf.name_scope('decodings'):
531 | self.image_summaries = {
532 | 'orig': self._add_decoding_summary('0_original_input', self.input),
533 | 'reco': self._add_decoding_summary('1_reconstruction', self.eval_decode),
534 | 'pred': self._add_decoding_summary('2_prediction', self.eval_decode),
535 | 'midd': self._add_decoding_summary('3_averaged', self.eval_decode),
536 | 'nois': self._add_decoding_summary('4_noisy', self.eval_decode)
537 | }
538 | # visualization
539 | fig = vis.get_figure()
540 | fig.canvas.draw()
541 | self.vis_placeholder = tf.placeholder(tf.uint8, ut.fig2rgb_array(fig).shape)
542 | self.vis_summary = tf.summary.image('visualization', self.vis_placeholder)
543 | # embedding
544 | dists = l2(self.embedding_test[:-1] - self.embedding_test[1:])
545 | self.dist = dists
546 | metrics = []
547 |
548 | metrics.append(tf.summary.histogram('point_distance', dists))
549 | metrics.append(tf.summary.scalar('training/trajectory_length', tf.reduce_sum(dists)))
550 | self.blur_ph = tf.placeholder(dtype=tf.float32)
551 | metrics.append(tf.summary.scalar('training/blur_sigma', self.blur_ph))
552 |
553 | pred = self.embedding_test[1:-1]*2 - self.embedding_test[0:-2]
554 | pred_error = l2(pred - self.embedding_test[2:])
555 |
556 | mean_dist, mean_pred_error = tf.reduce_mean(dists), tf.reduce_mean(pred_error)
557 | improvement = (mean_dist-mean_pred_error)/mean_dist
558 |
559 | pairwise_improvement = tf.nn.relu(dists[1:] - pred_error)
560 | pairwise_improvement_bool = tf.cast(pairwise_improvement > 0, pairwise_improvement.dtype)
561 | self.pairwise_improvement_bool = pairwise_improvement_bool
562 |
563 | metrics.append(tf.summary.scalar('training/avg_dist', mean_dist))
564 | metrics.append(tf.summary.scalar('training/pred_dist', mean_pred_error))
565 | metrics.append(tf.summary.scalar('training/improvement', improvement))
566 | metrics.append(tf.summary.scalar('training/improvement_abs', tf.nn.relu(improvement)))
567 | metrics.append(tf.summary.histogram('training/improvement_abs_hist', nut.nan_to_zero(improvement)))
568 | metrics.append(tf.summary.scalar('training/improvement_pairwise', tf.reduce_mean(pairwise_improvement_bool)))
569 | metrics.append(tf.summary.histogram('training/improvement_pairwise_hist', pairwise_improvement_bool))
570 | self.eval_summs = tf.summary.merge(metrics)
571 |
572 |
573 | def _build_embedding_saver(self, sess):
574 | """To use embedding visualizer data has to be stored in variable
575 | since we would like to visualize TEST_SET, this variable should not affect
576 | common checkpoint of the model.
577 | Hence, we build a separate variable with a separate saver."""
578 | embedding_shape = [int(len(self.test_set) / FLAGS.batch_size) * FLAGS.batch_size,
579 | self.encode.get_shape().as_list()[1]]
580 | tsv_path = os.path.join(FLAGS.logdir, 'metadata.tsv')
581 |
582 | self.embedding_test_ph = tf.placeholder(tf.float32, embedding_shape, name='embedding')
583 | self.embedding_test = tf.Variable(tf.random_normal(embedding_shape), name='test_embedding', trainable=False)
584 | self.embedding_assign = self.embedding_test.assign(self.embedding_test_ph)
585 | self.embedding_saver = tf.train.Saver(var_list=[self.embedding_test])
586 |
587 | config = projector.ProjectorConfig()
588 | embedding = config.embeddings.add()
589 | embedding.tensor_name = self.embedding_test.name
590 | embedding.sprite.image_path = './sprite.png'
591 | embedding.sprite.single_image_dim.extend([80, 80])
592 | embedding.metadata_path = './metadata.tsv'
593 | projector.visualize_embeddings(self.summary_writer, config)
594 | sess.run(tf.variables_initializer([self.embedding_test], name='init_embeddings'))
595 |
596 | # build sprite image
597 | ut.images_to_sprite(self.test_set, path=os.path.join(FLAGS.logdir, 'sprite.png'))
598 | ut.generate_tsv(len(self.test_set), tsv_path)
599 |
600 | def _add_loss_summary(self, name, var, collection='train'):
601 | if var is not None:
602 | tf.summary.scalar(name, var, [collection])
603 | tf.summary.scalar('log_' + name, tf.log(var), [collection])
604 |
605 | def _restore_model(self, session):
606 | latest_checkpoint = self.get_latest_checkpoint()
607 | print(latest_checkpoint)
608 | if latest_checkpoint is not None:
609 | latest_checkpoint = latest_checkpoint.replace(EMB_SUFFIX, '')
610 | ut.print_info("latest checkpoint: %s" % latest_checkpoint)
611 | if FLAGS.load_state and latest_checkpoint is not None:
612 | self.saver.restore(session, latest_checkpoint)
613 |       ut.print_info('Restored as requested. Previous epoch: %d' % self.get_past_epochs(), color=31)
614 |
615 | def _on_batch_finish(self, summs, loss, batch=None, encoding=None, reconstruction=None):
616 | self.summary_writer.add_summary(summs, global_step=self.step.eval())
617 | self.epoch_stats.total_loss += loss
618 |
619 | if False:
620 | assert batch is not None and reconstruction is not None
621 | original = batch[0]
622 | vis.plot_reconstruction(original, reconstruction, interactive=True)
623 |
624 | # @ut.timeit
625 | def _on_epoch_finish(self, epoch, start_time, sess):
626 | elapsed = time.time() - start_time
627 | self.epoch_stats.total_loss = guard_nan(self.epoch_stats.total_loss)
628 | accuracy = np.nan_to_num(100000 * np.sqrt(self.epoch_stats.total_loss / np.prod(self.batch_shape) / self.epoch_size))
629 | # SAVE
630 | if is_stopping_point(epoch, FLAGS.max_epochs, FLAGS.save_every):
631 | self.saver.save(sess, self.get_checkpoint_path())
632 | # VISUALIZE
633 | if is_stopping_point(epoch, FLAGS.max_epochs, FLAGS.eval_every):
634 | evaluation = self.evaluate(sess, take=FLAGS.visualiza_max)
635 | data = {
636 | 'enc': np.asarray(evaluation.encoded),
637 | 'rec': np.asarray(evaluation.reconstructed),
638 | 'blu': np.asarray(evaluation.source)
639 | }
640 | error_info = '%d(%d.%d.%d)' % (np.nan_to_num(accuracy),
641 | np.nan_to_num(evaluation.loss)/evaluation.size,
642 | np.nan_to_num(evaluation.eval_loss)/evaluation.size,
643 | np.nan_to_num(evaluation.dumb_loss)/evaluation.size)
644 | meta = Bunch(suf='encodings', e='%06d' % int(self.get_past_epochs()), er=error_info)
645 | # print(data, meta.to_file_name(folder=FLAGS.save_path))
646 | np.save(meta.to_file_name(folder=FLAGS.save_path), data)
647 | vis.plot_encoding_crosssection(
648 | evaluation.encoded,
649 | meta.to_file_name(FLAGS.save_path, 'jpg'),
650 | evaluation.source,
651 | evaluation.reconstructed,
652 | interactive=FLAGS.dev)
653 | self._save_visualization_to_summary()
654 | self.stats.epoch_accuracy.append(accuracy)
655 | self._print_epoch_info(accuracy, epoch, FLAGS.max_epochs, elapsed)
656 | if epoch + 1 != FLAGS.max_epochs:
657 | self.epoch_stats = get_stats_template()
658 |
659 | def _save_visualization_to_summary(self):
660 | image = ut.fig2rgb_array(plt.figure(num=0))
661 | self.summary_writer.add_summary(self.vis_summary.eval(feed_dict={self.vis_placeholder: image}))
662 |
663 | def _print_epoch_info(self, accuracy, current_epoch, epochs, elapsed):
664 | epochs_past = self.get_past_epochs() - current_epoch
665 | accuracy_info = '' if accuracy is None else '| accuracy %d' % int(accuracy)
666 | epoch_past_info = '' if epochs_past is None else '+%d' % (epochs_past - 1)
667 | epoch_count = 'Epochs %2d/%d%s' % (current_epoch + 1, epochs, epoch_past_info)
668 | time_info = '%2dms/bt' % (elapsed / self.epoch_size * 1000)
669 |
670 | examples = int(np.floor(len(self.train_set) / FLAGS.batch_size))
671 | loss_info = 't.loss:%d' % (self.epoch_stats.total_loss * 100 / (examples * np.prod(self.batch_shape[1:])))
672 |
673 | info_string = ' '.join([epoch_count, accuracy_info, time_info, loss_info])
674 | ut.print_time(info_string, same_line=True)
675 |
676 | def _on_training_finish(self, sess):
677 | if FLAGS.max_epochs == 0:
678 | self._on_epoch_finish(self.get_past_epochs(), time.time(), sess)
679 | best_acc = np.min(self.stats.epoch_accuracy)
680 | ut.print_time('Best Quality: %f for %s' % (best_acc, FLAGS.net))
681 | self.summary_writer.close()
682 |
683 | def _on_training_abort(self, sess):
684 | print('Press ENTER to save the model')
685 | if getch.getch() == '\n':
686 | print('saving')
687 | self.saver.save(sess, self.get_checkpoint_path())
688 |
689 |
690 | if __name__ == '__main__':
691 | args = dict([arg.split('=', maxsplit=1) for arg in sys.argv[1:]])
692 | if len(args) <= 1:
693 | FLAGS.input_path = '../data/tmp/romb8.5.6.tar.gz'
694 | FLAGS.test_path = '../data/tmp/romb8.5.6.tar.gz'
695 | FLAGS.test_max = 2178
696 | FLAGS.max_epochs = 5
697 | FLAGS.eval_every = 1
698 | FLAGS.save_every = 1
699 | FLAGS.batch_size = 32
700 | FLAGS.blur = 0.0
701 |
702 | # FLAGS.model = 'noise'
703 | # FLAGS.beta = 1.0
704 | # FLAGS.epsilon = .000001
705 |
706 | model = Autoencoder()
707 | if FLAGS.model == 'ae':
708 | FLAGS.model = AUTOENCODER
709 | elif 'pred' in FLAGS.model:
710 | print('PREDICTIVE')
711 | FLAGS.model = PREDICTIVE
712 | elif 'noi' in FLAGS.model:
713 | print('DENOISING')
714 | FLAGS.model = DENOISING
715 | else:
716 | print('Do-di-li-doo doo-di-li-don')
717 | model.train()
718 |
--------------------------------------------------------------------------------
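
Note on autoencoder.py: `_build_embedding_saver` above wires the TensorBoard projector to a dedicated, non-trainable variable so the test-set embedding can be visualized without polluting the model checkpoint. A minimal self-contained sketch of the same wiring (TF1-era API; the log directory, shapes, and file names here are placeholders, not the repo's actual values):

    import tensorflow as tf
    from tensorflow.contrib.tensorboard.plugins import projector  # TF1 import path

    logdir = '/tmp/logdir'  # placeholder
    emb = tf.Variable(tf.random_normal([1024, 3]), name='test_embedding', trainable=False)
    writer = tf.summary.FileWriter(logdir)

    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = emb.name                    # one row per test image
    embedding.sprite.image_path = './sprite.png'        # thumbnail atlas, 80x80 tiles
    embedding.sprite.single_image_dim.extend([80, 80])
    embedding.metadata_path = './metadata.tsv'          # one label per row
    projector.visualize_embeddings(writer, config)

    with tf.Session() as sess:
      sess.run(tf.variables_initializer([emb]))
      tf.train.Saver(var_list=[emb]).save(sess, logdir + '/embedding.ckpt')

The separate `tf.train.Saver(var_list=[emb])` mirrors the `embedding_saver` above: the projector reads the variable from its own checkpoint, so the main model checkpoint stays untouched.
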
/data_preprocessing/generate.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir tmp
3 | find . -type d -depth 2 -print | grep -v 'tmp' | cpio -pd ./tmp
4 |
5 | up=$"../../../tmp/"
6 |
7 | find . -type d -depth 3 -print | grep -v 'tmp' | while read d; do
8 | stringA=$up$d
9 | len=${#stringA}
10 | len=$((len-7))
11 | echo ${stringA}
12 | stringA=${stringA::len}
13 | (cd $d && mogrify -path $stringA -resize 48x32 -crop 32x32+0+0 -quality 100 *.jpg)
14 | done
15 |
--------------------------------------------------------------------------------
/data_preprocessing/movie.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mencoder "mf://*.jpg" -mf fps=5 -o out.avi -ovc lavc -lavcopts vcodec=msmpeg4v2:vbitrate=640
3 |
--------------------------------------------------------------------------------
/data_preprocessing/resize.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #-type Grayscale
3 | #cd $d && mogrify -path $stringA -type Grayscale -resize 160x120 -crop 80x80+40+0 -quality 100 *.jpg
4 | rm -R tmp
5 | mkdir tmp
6 | find . -mindepth 2 -maxdepth 2 -type d -print | grep -v 'tmp' | cpio -pd ./tmp
7 |
8 | up=$"../../../tmp/"
9 |
10 | find . -mindepth 3 -maxdepth 3 -type d -print | grep -v 'tmp' | while read d; do
11 | stringA=$up$d
12 | len=${#stringA}
13 | len=$((len-7))
14 | echo ${stringA}
15 | stringA=${stringA::len}
16 | (cd $d && mogrify -path $stringA -resize 160x120 -crop 120x120+20+0 -quality 100 *.jpg)
17 | done
18 |
--------------------------------------------------------------------------------
/data_preprocessing/resize.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo hello world
3 | mogrify -path ../32_32 -thumbnail 32x32 *.png
4 | mogrify -path thumbnail-directory -thumbnail 100x100 *
5 |
6 | mkdir 32_32
7 | cd 640_480
8 | mogrify -path ./../32_32 -resize 48x32 -crop 32x32+0+0 -quality 100 *.jpg
--------------------------------------------------------------------------------
/data_preprocessing/resize_grey.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #-type Grayscale
3 | #cd $d && mogrify -path $stringA -type Grayscale -resize 160x120 -crop 80x80+40+0 -quality 100 *.jpg
4 | rm -R tmp_grey
5 | mkdir tmp_grey
6 |
7 |
8 | find . -mindepth 2 -maxdepth 2 -type d -print | grep -v 'tmp' | cpio -pd ./tmp_grey/
9 |
10 | up=$"../../../tmp_grey/"
11 |
12 | find . -mindepth 3 -maxdepth 3 -type d -print | grep -v 'tmp' | while read d; do
13 | stringA=$up$d
14 | len=${#stringA}
15 | len=$((len-7))
16 | echo ${stringA}
17 | stringA=${stringA::len}
18 | # echo ${stringA}
19 | (cd $d && mogrify -path $stringA -resize 160x120 -crop 40x40+60+40 -fx '(r+g+b)/3' -quality 100 *.jpg)
20 | done
21 |
22 | chmod 777 tmp_grey
23 |
--------------------------------------------------------------------------------
/data_preprocessing/resize_no_gun.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #-type Grayscale
3 | #cd $d && mogrify -path $stringA -type Grayscale -resize 160x120 -crop 80x80+40+0 -quality 100 *.jpg
4 | rm -R tmp
5 | mkdir tmp
6 | find . -mindepth 2 -maxdepth 2 -type d -print | grep -v 'tmp' | cpio -pd ./tmp
7 |
8 | up=$"../../../tmp/"
9 |
10 | find . -mindepth 3 -maxdepth 3 -type d -print | grep -v 'tmp' | while read d; do
11 | stringA=$up$d
12 | len=${#stringA}
13 | len=$((len-7))
14 | echo ${stringA}
15 | stringA=${stringA::len}
16 | (cd $d && mogrify -path $stringA -resize 160x120 -crop 160x80+0+0 -quality 100 *.jpg)
17 | done
18 |
19 | chmod 777 tmp
20 |
--------------------------------------------------------------------------------
/data_preprocessing/tar_all.command:
--------------------------------------------------------------------------------
1 | ext=$".tar.gz"
2 | find ./tmp/ -type d -maxdepth 1 -mindepth 1 | while read d; do
3 | tar -zcvf $d$ext $d/
4 | done
--------------------------------------------------------------------------------
/data_preprocessing/to32_32.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir 32_32
3 | cd 640_480
4 | mogrify -path ./../32_32 -resize 48x32 -crop 32x32+0+0 -quality 100 *.jpg
--------------------------------------------------------------------------------
/experiments.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import utils as ut
4 | import input
5 | import DoomModel as dm
6 | import pickle
7 | from datetime import datetime as dt
8 | import sys
9 |
10 | FLAGS = tf.app.flags.FLAGS
11 |
12 |
13 | def search_learning_rate(lrs=[0.001, 0.0004, 0.0001, 0.00003,],
14 | epochs=500):
15 | FLAGS.suffix = 'grid_lr'
16 | ut.print_info('START: search_learning_rate', color=31)
17 |
18 | best_result, best_args = None, None
19 | result_summary, result_list = [], []
20 |
21 | for lr in lrs:
22 | ut.print_info('STEP: search_learning_rate', color=31)
23 | FLAGS.learning_rate = lr
24 | model = model_class()
25 | meta, accuracy_by_epoch = model.train(epochs)
26 | result_list.append((ut.to_file_name(meta), accuracy_by_epoch))
27 | best_accuracy = np.min(accuracy_by_epoch)
28 | result_summary.append('\n\r lr:%2.5f \tq:%.2f' % (lr, best_accuracy))
29 | if best_result is None or best_result > best_accuracy:
30 | best_result = best_accuracy
31 | best_args = lr
32 |
33 | meta = {'suf': 'grid_lr_bs', 'e': epochs, 'lrs': lrs, 'acu': best_result,
34 | 'bs': FLAGS.batch_size, 'h': model.get_layer_info()}
35 | pickle.dump(result_list, open('search_learning_rate%d.txt' % epochs, "wb"))
36 | ut.plot_epoch_progress(meta, result_list)
37 | print(''.join(result_summary))
38 | ut.print_info('BEST Q: %d IS ACHIEVED FOR LR: %f' % (best_result, best_args), 36)
39 |
40 |
41 | def search_batch_size(bss=[50], strides=[1, 2, 5, 20], epochs=500):
42 | FLAGS.suffix = 'grid_bs'
43 | ut.print_info('START: search_batch_size', color=31)
44 | best_result, best_args = None, None
45 | result_summary, result_list = [], []
46 |
47 | print(bss)
48 | for bs in bss:
49 | for stride in strides:
50 | ut.print_info('STEP: search_batch_size %d %d' % (bs, stride), color=31)
51 | FLAGS.batch_size = bs
52 | FLAGS.stride = stride
53 | model = model_class()
54 | start = dt.now()
55 | # meta, accuracy_by_epoch = model.train(epochs * int(bs / bss[0]))
56 | meta, accuracy_by_epoch = model.train(epochs)
57 | meta['str'] = stride
58 | meta['t'] = int((dt.now() - start).seconds)
59 | result_list.append((ut.to_file_name(meta)[22:], accuracy_by_epoch))
60 | best_accuracy = np.min(accuracy_by_epoch)
61 | result_summary.append('\n\r bs:%d \tst:%d \tq:%.2f' % (bs, stride, best_accuracy))
62 | if best_result is None or best_result > best_accuracy:
63 | best_result = best_accuracy
64 | best_args = (bs, stride)
65 |
66 | meta = {'suf': 'grid_batch_bs', 'e': epochs, 'acu': best_result,
67 | 'h': model.get_layer_info()}
68 | pickle.dump(result_list, open('search_batch_size%d.txt' % epochs, "wb"))
69 | ut.plot_epoch_progress(meta, result_list)
70 | print(''.join(result_summary))
71 |
72 | ut.print_info('BEST Q: %d IS ACHIEVED FOR bs, st: %d %d' % (best_result, best_args[0], best_args[1]), 36)
73 |
74 |
75 | def search_layer_sizes(epochs=500):
76 | FLAGS.suffix = 'grid_h'
77 | ut.print_info('START: search_layer_sizes', color=31)
78 | best_result, best_args = None, None
79 | result_summary, result_list = [], []
80 |
81 | for _, h_encoder in enumerate([300, 700, 2500]):
82 | for _, h_decoder in enumerate([300, 700, 2500]):
83 | for _, h_narrow in enumerate([3]):
84 | model = model_class()
85 | model.layer_encoder = h_encoder
86 | model.layer_narrow = h_narrow
87 | model.layer_decoder = h_decoder
88 | layer_info = str(model.get_layer_info())
89 | ut.print_info('STEP: search_layer_sizes: ' + str(layer_info), color=31)
90 |
91 | meta, accuracy_by_epoch = model.train(epochs)
92 | result_list.append((layer_info, accuracy_by_epoch))
93 | best_accuracy = np.min(accuracy_by_epoch)
94 | result_summary.append('\n\r h:%s \tq:%.2f' % (layer_info, best_accuracy))
95 | if best_result is None or best_result > best_accuracy:
96 | best_result = best_accuracy
97 | best_args = layer_info
98 |
99 | meta = {'suf': 'grid_H_bs', 'e': epochs, 'acu': best_result,
100 | 'bs': FLAGS.batch_size, 'h': model.get_layer_info()}
101 | print(''.join(result_summary))
102 | pickle.dump(result_list, open('search_layer_sizes%d.txt' % epochs, "wb"))
103 | ut.print_info('BEST Q: %d IS ACHIEVED FOR H: %s' % (best_result, best_args), 36)
104 | ut.plot_epoch_progress(meta, result_list)
105 |
106 |
107 | def search_layer_sizes_follow_up():
108 | """train further 2 best models"""
109 | FLAGS.save_every = 200
110 | for i in range(4):
111 | model = model_class()
112 | model.layer_encoder = 500
113 | model.layer_narrow = 3
114 | model.layer_decoder = 100
115 | model.train(600)
116 |
117 | model = model_class()
118 | model.layer_encoder = 500
119 | model.layer_narrow = 12
120 | model.layer_decoder = 500
121 | model.train(600)
122 |
123 |
124 | def print_reconstructions_along_with_originals():
125 | FLAGS.load_from_checkpoint = './tmp/doom_bs__act|sigmoid__bs|20__h|500|5|500__init|na__inp|cbd4__lr|0.0004__opt|AO'
126 | model = model_class()
127 | files = ut.list_encodings(FLAGS.save_path)
128 | last_encoding = files[-1]
129 | print(last_encoding)
130 | take_only = 20
131 | data = np.loadtxt(last_encoding)[0:take_only]
132 | reconstructions = model.decode(data)
133 | original, _ = input.get_images(FLAGS.input_path, at_most=take_only)
134 | ut.print_side_by_side(original, reconstructions)
135 |
136 |
137 | def train_couple_8_models():
138 | FLAGS.input_path = '../data/tmp/8_pos_delay_3/img/'
139 |
140 | model = model_class()
141 | model.set_layer_sizes([500, 5, 500])
142 | for i in range(10):
143 | model.train(1000)
144 |
145 | model = model_class()
146 | model.set_layer_sizes([1000, 10, 1000])
147 | for i in range(20):
148 | model.train(1000)
149 |
150 |
151 | if __name__ == "__main__":
152 | # run function if provided as console params
153 | epochs = 100
154 | model_class = dm.DoomModel
155 | experiment = search_learning_rate
156 |
157 | if len(sys.argv) > 1:
158 | print(sys.argv)
159 | experiment = sys.argv[1]
160 | if experiment not in locals():
161 | ut.print_info('Function "%s" not found. List of available functions:' % experiment)
162 | ut.print_info('\n'.join([x for x in locals() if 'search' in x]))
163 | exit(0)
164 | experiment = locals()[experiment]
165 | if len(sys.argv) > 2:
166 | epochs = int(sys.argv[2])
167 | if len(sys.argv) > 3:
168 | m = __import__(sys.argv[3])
169 | model_class = getattr(m, sys.argv[3])
170 |
171 | FLAGS.suffix = 'grid'
172 | # FLAGS.input_path = '../data/tmp/8_pos_delay/img/'
173 |
174 | experiment(epochs=epochs)
175 |
176 | # search_layer_sizes(epochs=epochs)
177 | # search_batch_size(epochs=epochs)
178 | # FLAGS.batch_size = 40
179 |
--------------------------------------------------------------------------------
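
Note on experiments.py: the `__main__` block above dispatches by function name via `locals()`, so the script is driven entirely from the command line: argv[1] names the experiment, argv[2] the epoch count, and argv[3] a module that must define a model class of the same name. Hypothetical invocations (module/class names are assumptions based on the files in this repo):

    # python experiments.py search_batch_size 300
    # python experiments.py search_layer_sizes 500 IGNModel   # assumes IGNModel.py defines class IGNModel
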
/input.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import utils as ut
4 | import json
5 | import scipy.ndimage.filters as filters
6 | import time
7 | from PIL import Image
8 | import tarfile
9 | import io
10 |
11 |
12 | INPUT_FOLDER = '../data/circle_basic_1/img/32_32'
13 |
14 |
15 | def _is_combination_of_image_depth(folder):
16 | return '/dep' not in folder and '/img' not in folder
17 |
18 |
19 | def get_action_data(folder):
20 | folder = folder.replace('/tmp_grey', '')
21 | folder = folder.replace('/tmp', '')
22 | folder = folder.replace('/img', '')
23 | folder = folder.replace('/dep', '')
24 | file = os.path.join(folder, 'action.txt')
25 | if not os.path.exists(file):
26 | return np.asarray([])
27 | action_data = json.load(open(file, 'r'))[:]
28 | # print(action_data)
29 | res = []
30 | # for i, action in enumerate(action_data):
31 | # print(action)
32 | # res.append(
33 | # (
34 | # action[0],
35 | # action[1],
36 | # action[2][3] or action[2][4] or action[2][5] or action[2][6],
37 | # action[2][18] != 0
38 | # )
39 | # )
40 | # print([print(x[0]) for x in action_data[0:10]])
41 | res = [x[3][:2] for x in action_data]
42 | return np.abs(np.asarray(res))
43 |
44 |
45 | def read_ds_zip(path):
46 | dep, img = {}, {}
47 | tar = tarfile.open(path, "r:gz")
48 |
49 | for member in tar.getmembers():
50 | if '.jpg' not in member.name or not ('/dep/' in member.name or '/img/' in member.name):
51 | # print('skipped', member)
52 | continue
53 | collection = dep if '/dep/' in member.name else img
54 | index = int(member.name.split('/')[-1][1:-4])
55 | f = tar.extractfile(member)
56 | if f is not None:
57 | content = f.read()
58 | image = Image.open(io.BytesIO(content))
59 | collection[index] = np.array(image)
60 | assert len(img) == len(dep)
61 |
62 | shape = [len(img)] + list(img[index].shape)
63 | shape[-1] += 1
64 |
65 | dataset = np.zeros(shape, np.uint8)
66 | for i, k in enumerate(sorted(img)):
67 | dataset[i, ..., :-1] = img[k]
68 | dataset[i, ..., -1] = dep[k]
69 |   return dataset  # , shape[1:]
70 |
71 |
72 | def get_shape_zip(path):
73 | tar = tarfile.open(path, "r:gz")
74 | for member in tar.getmembers():
75 | if '.jpg' not in member.name or '/img/' not in member.name:
76 | continue
77 | f = tar.extractfile(member)
78 | content = f.read()
79 | image = Image.open(io.BytesIO(content))
80 | shape = list(np.array(image).shape)
81 | shape[-1] += 1
82 | return shape
83 |
84 |
85 | def rescale_ds(ds, min, max):
86 | ut.print_info('rescale call: (min: %s, max: %s) %d' % (str(min), str(max), len(ds)))
87 | if max is None:
88 | return np.asarray(ds) - np.min(ds)
89 | ds_min, ds_max = np.min(ds), np.max(ds)
90 | ds_gap = ds_max - ds_min
91 | scale_factor = (max - min) / ds_gap
92 | ds = np.asarray(ds) * scale_factor
93 | shift_factor = min - np.min(ds)
94 | ds += shift_factor
95 | return ds
96 |
97 |
98 | def get_input_name(input_folder):
99 |   splitter = '/img/' if '/img/' in input_folder else '/dep/'
100 |   main_part = input_folder.split(splitter)[0]
101 | name = main_part.split('/')[-1]
102 | name = name.replace('.tar.gz', '')
103 | ut.print_info('input folder: %s -> %s' % (input_folder.split('/'), name))
104 | return name
105 |
106 |
107 | def permute_array(array, random_state=None):
108 | return permute_data((array,))[0]
109 |
110 |
111 | def permute_data(arrays, random_state=None):
112 | """Permute multiple numpy arrays with the same order."""
113 | if any(len(a) != len(arrays[0]) for a in arrays):
114 | raise ValueError('All arrays must be the same length.')
115 | if not random_state:
116 | random_state = np.random
117 | order = random_state.permutation(len(arrays[0]))
118 | return [a[order] for a in arrays]
119 |
120 |
121 | def apply_gaussian(images, sigma):
122 | if sigma == 0:
123 | return images
124 |
125 | res = images.copy()
126 | for i, image in enumerate(res):
127 | for channel in range(image.shape[-1]):
128 | image[:, :, channel] = filters.gaussian_filter(image[:, :, channel], sigma)
129 | return res
130 |
131 |
132 | def permute_array_in_series(array, series_length, allow_shift=True):
133 | res, permutation = permute_data_in_series((array,), series_length)
134 | return res[0], permutation
135 |
136 |
137 | def permute_data_in_series(arrays, series_length, allow_shift=True):
138 | shift_possibilities = len(arrays[0]) % series_length
139 | series_count = int(len(arrays[0]) / series_length)
140 |
141 | shift = 0
142 | if allow_shift:
143 | if shift_possibilities == 0:
144 | shift_possibilities += series_length
145 | series_count -= 1
146 | shift = np.random.randint(0, shift_possibilities+1, 1, dtype=np.int32)[0]
147 |
148 | series = np.arange(0, series_count * series_length)\
149 | .astype(np.int32)\
150 | .reshape((series_count, series_length))
151 |
152 | series = np.random.permutation(series)
153 | data_permutation = series.reshape((series_count * series_length))
154 | data_permutation += shift
155 |
156 | remaining_elements = np.arange(0, len(arrays[0])).astype(np.int32)
157 | remaining_elements = np.delete(remaining_elements, data_permutation)
158 | data_permutation = np.concatenate((data_permutation, remaining_elements))
159 |
160 | # print('assert', len(arrays[0]), len(data_permutation))
161 | assert len(data_permutation) == len(arrays[0])
162 | return [a[data_permutation] for a in arrays], data_permutation
163 |
164 |
165 | def pad_set(set, batch_size):
166 | length = len(set)
167 | if length % batch_size == 0:
168 | return set
169 | padding_len = batch_size - length % batch_size
170 | if padding_len != 0:
171 | pass
172 | # ut.print_info('Non-zero padding: %d' % padding_len, color=31)
173 | # print('pad set', set.shape, select_random(padding_len, set=set).shape)
174 | return np.concatenate((set, select_random(padding_len, set=set)))
175 |
176 |
177 | def select_random(n, length=None, set=None):
178 | assert length is None or set is None
179 | length = length if set is None else len(set)
180 | select = np.random.permutation(np.arange(length, dtype=np.int))[:n]
181 | if set is None:
182 | return select
183 | else:
184 | return set[select]
185 |
186 | if __name__ == '__main__':
187 | print(read_ds_zip('/home/eugene/repo/data/tmp/romb8.5.6.tar.gz').shape)
--------------------------------------------------------------------------------
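
Note on input.py: `permute_data_in_series` shuffles whole contiguous windows rather than individual frames, so short-term temporal order survives inside each series; the optional random shift keeps window boundaries from always falling at the same offsets. A minimal sketch of its behaviour (the concrete order is random; the one in the comment is just one possible outcome):

    import numpy as np
    data = np.arange(12)
    (shuffled,), order = permute_data_in_series((data,), series_length=3, allow_shift=False)
    # `order` is 0..11 rearranged in shuffled blocks of 3 consecutive indices,
    # e.g. [6 7 8 | 0 1 2 | 9 10 11 | 3 4 5]
    assert sorted(order) == list(range(12))
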
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import utils as ut
4 | import sys
5 | import os
6 | import sklearn.metrics as m
7 | import visualization as vis
8 | import matplotlib.pyplot as plt
9 | import matplotlib.ticker as ticker
10 |
11 |
12 |
13 | def get_evaluation(path):
14 | f = ut.get_latest_file(path, filter=r'.*.npy$')
15 | print(f)
16 | dict = np.load(f).item()
17 | return dict
18 |
19 |
20 | def distance(ref, pred):
21 | error = ref - pred
22 | return l2(error)
23 |
24 |
25 | def l2(error):
26 | error = error ** 2
27 | error = np.sum(error, axis=1)
28 | return np.sqrt(error)
29 |
30 |
31 | def distance_improvement(prediction_dist, naive_dist):
32 | pred_mean, naive_mean = np.mean(prediction_dist), np.mean(naive_dist)
33 | improvement = naive_mean-pred_mean if naive_mean-pred_mean > 0 else 0
34 | improvement = improvement / naive_mean
35 | # print('predictive error(naive): %.9f (%.9f) -> %.2f%%'
36 | # % (pred_mean, naive_mean, improvement*100))
37 | return improvement, pred_mean, naive_mean
38 |
39 |
40 | def distance_binary_improvement(prediction_dist, naive_dist):
41 | pairwise = naive_dist - prediction_dist
42 | pairwise[pairwise > 0] = 1
43 | pairwise[pairwise < 0] = 0
44 | fraction = np.mean(pairwise)
45 | # print('Pairwise error improved for : %f%%' % (fraction*100), pairwise)
46 | return fraction
47 |
48 |
49 | def nn_metric(point_array):
50 | """
51 |   For every frame, how often the previous frame t-1 and the next frame t+1 are within the top-2 nearest neighbours.
52 | :param point_array: point array
53 | :return:
54 | """
55 | total = 0
56 | for i in range(1, len(point_array)-1):
57 | x = point_array - point_array[i]
58 | d = l2(x)
59 | indexes = np.argsort(d)[:3]
60 | # assert i in indexes or d[indexes[0]][0] == 0
61 | if i-1 in indexes:
62 | total += 1
63 | if i+1 in indexes:
64 | total += 1
65 | # print(i, i-1 in indexes, i+1 in indexes, indexes)
66 | metric = total/(len(point_array) - 2) / 2
67 | # print('NN metric: %.7f%%' % (metric * 100))
68 | return metric
69 |
70 |
71 | def nn_metric_pred(prediction, target):
72 | """
73 |   For every frame, how often its prediction is the nearest neighbour of the true target frame.
74 | :param target: point array
75 | :return:
76 | """
77 | total = 0
78 | for i in range(len(target)):
79 | x = target - prediction[i]
80 | d = l2(x)
81 | index = np.argmin(d)
82 | # print(index, i, d)
83 | if i == index:
84 | total += 1
85 | metric = total/len(target)
86 |   # print('NN metric for prediction: %.7f%%' % (metric * 100))
87 | return metric
88 |
89 |
90 | def test_nn():
91 | enc = np.arange(0, 100).reshape((100, 1))
92 | # enc.transpose()
93 | # print(enc.shape)
94 | # print(nn_metric(enc))
95 | assert nn_metric(enc) == 1.
96 |
97 |
98 | def test_nn_pred():
99 | enc = np.arange(0, 100).reshape((100, 1))
100 | # enc.transpose()
101 | # print(enc.shape)
102 | # print(nn_metric_pred(enc, enc+0.2))
103 | assert nn_metric_pred(enc, enc) == 1.
104 |
105 |
106 | def reco_error(x, y):
107 | delta = x-y
108 | error = m.mean_squared_error(x.flatten(), y.flatten())
109 | return error
110 |
111 |
112 | def print_folder_metrics(path):
113 | eval = get_evaluation(path)
114 | enc = eval['enc']
115 | pred = enc[1:-1]*2 - enc[0:-2]
116 | ref = enc[2:]
117 |
118 | # print(enc[0], enc[1], pred[0], enc[2])
119 |
120 | pred_to_target_dist, next_dist = distance(ref, pred), distance(enc[1:-1], enc[2:])
121 | pl2, pred_d, naiv_d = distance_improvement(pred_to_target_dist, next_dist)
122 | pb = distance_binary_improvement(pred_to_target_dist, next_dist)
123 | pnn = nn_metric(enc)
124 | pnnp = (nn_metric_pred(pred, ref)*100)
125 | lreco = reco_error(eval['rec'], eval['blu'])
126 | info = '%.3f & %.3f & %.3f & %.2f' % (pl2, pb, pnn, lreco)
127 | print(info)
128 | print('pimp:%f(%f/%f) & pb:%f & pnn:%f' % (pl2, pred_d, naiv_d, pb, pnn), 'pnnp: %f' % pnnp)
129 |
130 | return info
131 |
132 |
133 | def plot_single_cross_section_3d(data, select, subplot):
134 | data = data[:, select]
135 | # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0,
136 | # subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4)
137 |
138 | d = data
139 | # subplot.plot(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=1, alpha=0.8, color='red')
140 | # subplot.scatter(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=10, alpha=0.3, marker=".", color='b')
141 | d = data
142 | subplot.scatter(d[:, 0], d[:, 1], d[:, 2], s=4, alpha=1.0, lw=0.5,
143 | c=vis._build_radial_colors(len(d)),
144 | marker=".",
145 | cmap=plt.cm.hsv)
146 | subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=0.2, alpha=0.9)
147 |
148 | subplot.set_xlim([-0.01, 1.01])
149 | subplot.set_ylim([-0.01, 1.01])
150 | subplot.set_zlim([-0.01, 1.01])
151 | ticks = []
152 | subplot.xaxis.set_ticks(ticks)
153 | subplot.yaxis.set_ticks(ticks)
154 | subplot.zaxis.set_ticks(ticks)
155 | subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
156 | subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
157 |
158 |
159 | def plot_single_cross_section_line(data, select, subplot):
160 | data = data[:, select]
161 | # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0,
162 | # subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4)
163 |
164 | d = data
165 | # subplot.plot(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=1, alpha=0.8, color='red')
166 | # subplot.scatter(d[[-1, 0], 0], d[[-1, 0], 1], d[[-1, 0], 2], lw=10, alpha=0.3, marker=".", color='b')
167 | d = data
168 | subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4)
169 |
170 | subplot.set_xlim([-0.01, 1.01])
171 | subplot.set_ylim([-0.01, 1.01])
172 | subplot.set_zlim([-0.01, 1.01])
173 | ticks = []
174 | subplot.xaxis.set_ticks(ticks)
175 | subplot.yaxis.set_ticks(ticks)
176 | subplot.zaxis.set_ticks(ticks)
177 | subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
178 | subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
179 |
180 |
181 | if __name__ == '__main__':
182 | path = os.getcwd()
183 |
184 | for _, paths, _ in os.walk(path):
185 | print('dirs', paths)
186 | break
187 |
188 | if len(paths) == 0:
189 | print_folder_metrics(path)
190 |
191 | eval = get_evaluation(path)
192 | enc = eval['enc']
193 |
194 | fig = vis.get_figure(shape=[1800, 900, 3])
195 | # ax =
196 | plot_single_cross_section_3d(enc, [0, 1, 2], plt.subplot(121, projection='3d'))
197 | plot_single_cross_section_line(enc, [0, 1, 2], plt.subplot(122, projection='3d'))
198 | plt.tight_layout()
199 | plt.show()
200 | else:
201 | res = []
202 | for d in paths + [path]:
203 | print(d)
204 | c_path = os.path.join(path, d)
205 | info = print_folder_metrics(c_path)
206 | res.append('\n%30s:\n%s' % (d, info))
207 | res = sorted(res)
208 | print('\n'.join(res))
209 | print(len(paths), len(res))
210 | exit(0)
211 |
212 |
213 | if 'TensorFlow_DCIGN' in os.getcwd().split('/')[-1]:
214 | # path = '/home/eugene/repo/TensorFlow_DCIGN/tmp/noise.f20_f4__i_grid03.14.c'
215 | # path = '/mnt/code/vd/TensorFlow_DCIGN/tmp/pred.f101_f3__i_romb8.5.6'
216 | path = '/media/eugene/back up/VD_backup/tmp_epoch20_final/pred.16c3s2_32c3s2_32c3s2_23c3_f3__i_romb8.5.6_'
217 | # path = '/media/eugene/back up/VD_backup/tmp_epoch19_inputs/pred.16c3s2_32c3s2_32c3s2_16c3_f100_f3__i_grid.28.gh.360'
218 |
219 |
220 |
--------------------------------------------------------------------------------
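
Note on metrics.py: a tiny worked example of the improvement metrics above, using 1-D toy encodings. The linear extrapolation pred = 2*e[t] - e[t-1] is measured against the naive guess "the next frame equals the current one":

    import numpy as np
    enc = np.array([[0.0], [1.0], [2.0], [3.1]])   # nearly linear trajectory
    pred = enc[1:-1] * 2 - enc[0:-2]               # -> [[2.0], [3.0]]
    pred_dist = distance(enc[2:], pred)            # [0.0, 0.1]
    naive_dist = distance(enc[1:-1], enc[2:])      # [1.0, 1.1]
    imp, p_mean, n_mean = distance_improvement(pred_dist, naive_dist)
    # imp = (1.05 - 0.05) / 1.05 ~= 0.95: the prediction lands ~95% closer than naive
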
/model_interpreter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow.contrib.slim as slim
4 | import network_utils as nut
5 | import utils as ut
6 | import re
7 | import os
8 | from Bunch import Bunch
9 |
10 | INPUT = 'input'
11 | FC = 'fully_connected'
12 | CONV = 'convolutional'
13 | POOL = 'max_pooling'
14 | POOL_ARG = 'maxpool_with_args'
15 | DO = 'dropout'
16 | LOSS = 'loss'
17 |
18 | activation_voc = {
19 | 's': tf.nn.sigmoid,
20 | 'r': tf.nn.relu,
21 | 't': tf.nn.tanh,
22 | 'i': None
23 | }
24 |
25 | CONFIG_COLOR = 30
26 | PADDING = 'SAME'
27 |
28 |
29 | def clean_unpooling_masks(layer_config):
30 | """
31 |   Clears positional information from POOL_ARG configs.
32 | 
33 |   :param layer_config: list of layer descriptors
34 |   :return: list of argmax masks collected from the POOL_ARG layers
35 | """
36 | mask_list = [cfg.argmax for cfg in layer_config if cfg.type == POOL_ARG]
37 | for cfg in layer_config:
38 | if cfg.type == POOL_ARG:
39 | cfg.argmax = None
40 | return mask_list
41 |
42 |
43 | def build_autoencoder(input, layer_config):
44 | reuse_model = isinstance(layer_config, list)
45 | if not reuse_model:
47 | layer_config = layer_config.replace('_', '-').split('-')
48 | layer_config = [parse_input(input)] + [parse(x) for x in layer_config]
49 | if not reuse_model:
50 | ut.print_info('Model config:', color=CONFIG_COLOR)
51 | enc = build_encoder(input, layer_config, reuse=reuse_model)
52 | dec = build_decoder(enc, layer_config, reuse=reuse_model)
53 | mask_list = clean_unpooling_masks(layer_config)
54 | losses = build_losses(layer_config)
55 | return Bunch(
56 | encode=enc,
57 | decode=dec,
58 | losses=losses,
59 | config=layer_config,
60 | mask_list=mask_list)
61 |
62 |
63 | def build_encoder(net, layer_config, i=1, reuse=False):
64 | if i == len(layer_config):
65 | return net
66 |
67 | cfg = layer_config[i]
68 | cfg.shape = net.get_shape().as_list()
69 | name = cfg.enc_op_name if reuse else None
70 | cfg.ein = net
71 | if cfg.type == FC:
72 | if len(cfg.shape) > 2:
73 | net = slim.flatten(net)
74 | net = slim.fully_connected(net, cfg.size, activation_fn=cfg.activation,
75 | scope=name, reuse=reuse)
76 | elif cfg.type == CONV:
77 | net = slim.conv2d(net, cfg.size, [cfg.kernel, cfg.kernel], stride=cfg.stride,
78 | activation_fn=cfg.activation, padding=PADDING,
79 | scope=name, reuse=reuse)
80 | elif cfg.type == POOL_ARG:
81 | net, cfg.argmax = nut.max_pool_with_argmax(net, cfg.kernel)
82 | # if not reuse:
83 | # mask = nut.fake_arg_max_of_max_pool(cfg.shape, cfg.kernel)
84 | # cfg.argmax_dummy = tf.constant(mask.flatten(), shape=mask.shape)
85 | elif cfg.type == POOL:
86 | net = slim.max_pool2d(net, kernel_size=[cfg.kernel, cfg.kernel], stride=cfg.kernel)
87 | elif cfg.type == DO:
88 | net = tf.nn.dropout(net, keep_prob=cfg.keep_prob)
89 | elif cfg.type == LOSS:
90 | cfg.arg1 = net
91 | elif cfg.type == INPUT:
92 | assert False
93 |
94 | if not reuse:
95 | cfg.enc_op_name = net.name.split('/')[0]
96 | if not reuse:
97 | ut.print_info('\rencoder_%d\t%s\t%s' % (i, str(net), cfg.enc_op_name), color=CONFIG_COLOR)
98 | return build_encoder(net, layer_config, i + 1, reuse=reuse)
99 |
100 |
101 | def build_decoder(net, layer_config, i=None, reuse=False, masks=None):
102 | i = i if i is not None else len(layer_config) - 1
103 |
104 | cfg = layer_config[i]
105 | name = cfg.dec_op_name if reuse else None
106 | if len(layer_config) > i + 1:
107 | if len(layer_config[i + 1].shape) != len(net.get_shape().as_list()):
108 | net = tf.reshape(net, layer_config[i + 1].shape)
109 |
110 | if i < 0 or layer_config[i].type == INPUT:
111 | return net
112 |
113 | if cfg.type == FC:
114 | net = slim.fully_connected(net, int(np.prod(cfg.shape[1:])), scope=name,
115 | activation_fn=cfg.activation, reuse=reuse)
116 | elif cfg.type == CONV:
117 | net = slim.conv2d_transpose(net, cfg.shape[-1], [cfg.kernel, cfg.kernel], stride=cfg.stride,
118 | activation_fn=cfg.activation, padding=PADDING,
119 | scope=name, reuse=reuse)
120 | elif cfg.type == POOL_ARG:
121 | if cfg.argmax is not None or masks is not None:
122 | mask = cfg.argmax if cfg.argmax is not None else masks.pop()
123 | net = nut.unpool(net, mask=mask, stride=cfg.kernel)
124 | else:
125 | net = nut.upsample(net, stride=cfg.kernel, mode='COPY')
126 | elif cfg.type == POOL:
127 | net = nut.upsample(net, cfg.kernel)
128 | elif cfg.type == DO:
129 | pass
130 | elif cfg.type == LOSS:
131 | cfg.arg2 = net
132 | elif cfg.type == INPUT:
133 | assert False
134 | if not reuse:
135 | cfg.dec_op_name = net.name.split('/')[0]
136 | if not reuse:
137 | ut.print_info('\rdecoder_%d \t%s' % (i, str(net)), color=CONFIG_COLOR)
138 | cfg.dout = net
139 | return build_decoder(net, layer_config, i - 1, reuse=reuse, masks=masks)
140 |
141 |
142 | def build_stacked_losses(model):
143 | losses = []
144 | for i, cfg in enumerate(model.config):
145 | if cfg.type in [FC, CONV]:
146 | input = tf.stop_gradient(cfg.ein, name='stacked_breakpoint_%d' % i)
147 | net = build_encoder(input, [None, model.config[i]], reuse=True)
148 | net = build_decoder(net, [model.config[i]], reuse=True)
149 | losses.append(l2_loss(input, net, name='stacked_loss_%d' % i))
150 | model.stacked_losses = losses
151 |
152 |
153 | def build_losses(layer_config):
154 | return []
155 |
156 |
157 | def l2_loss(arg1, arg2, alpha=1.0, name='reco_loss'):
158 | with tf.name_scope(name):
159 | loss = tf.nn.l2_loss(arg1 - arg2)
160 | return alpha * loss
161 |
162 |
163 | def get_activation(descriptor):
164 | if 'c' not in descriptor and 'f' not in descriptor:
165 | return None
166 | activation = tf.nn.relu if 'c' in descriptor else tf.nn.sigmoid
167 |   act_descriptor = re.search('[r|s|i|t]$', descriptor)
168 | if act_descriptor is None:
169 | return activation
170 | act_descriptor = act_descriptor.group(0)
171 | return activation_voc[act_descriptor]
172 |
173 |
174 | def _get_cfg_dummy():
175 | return Bunch(enc_op_name=None, dec_op_name=None)
176 |
177 |
178 | def parse(descriptor):
179 | item = _get_cfg_dummy()
180 |
181 | match = re.match(r'^((\d+c\d+(s\d+)?[r|s|i|t]?)'
182 | r'|(f\d+[r|s|i|t]?)'
183 |                    r'|(d0?\.?[\d+]?)'
184 |                    r'|(\d*\.?\d+l)'
185 | r'|(p\d+)'
186 | r'|(ap\d+))$', descriptor)
187 | assert match is not None, 'Check your writing: %s (f10i-3c64r-d0.1-p2-ap2)' % descriptor
188 |
189 |
190 | if 'f' in descriptor:
191 | item.type = FC
192 | item.activation = get_activation(descriptor)
193 | item.size = int(re.search('f\d+', descriptor).group(0)[1:])
194 | elif 'c' in descriptor:
195 | item.type = CONV
196 | item.activation = get_activation(descriptor)
197 | item.kernel = int(re.search('c\d+', descriptor).group(0)[1:])
198 | stride = re.search('s\d+', descriptor)
199 | item.stride = int(stride.group(0)[1:]) if stride is not None else 1
200 | item.size = int(re.search('\d+c', descriptor).group(0)[:-1])
201 | elif 'd' in descriptor:
202 | item.type = DO
203 | item.keep_prob = float(descriptor[1:])
204 | elif 'ap' in descriptor:
205 | item.type = POOL_ARG
206 | item.kernel = int(descriptor[2:])
207 | elif 'p' in descriptor:
208 | item.type = POOL
209 | item.kernel = int(descriptor[1:])
210 | elif 'l' in descriptor:
211 | item.type = LOSS
212 | item.loss_type = 'l2'
213 | item.alpha = float(descriptor.split('l')[0])
214 | else:
215 | print('What is "%s"? Check your writing 16c2i-7c3r-p3-0.01l-f10t-d0.3' % descriptor)
216 | assert False
217 | return item
218 |
219 |
220 | def parse_input(input):
221 | item = _get_cfg_dummy()
222 | item.type = INPUT
223 | item.shape = input.get_shape().as_list()
224 | item.dout = input
225 | return item
226 |
227 |
228 | def _log_graph():
229 | path = '/tmp/interpreter'
230 | with tf.Session() as sess:
231 |     sess.run(tf.global_variables_initializer())
232 | tf.summary.FileWriter(path, sess.graph)
233 | ut.print_color(os.path.abspath(path), color=33)
--------------------------------------------------------------------------------
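
Note on model_interpreter.py: as a reading aid for the descriptor grammar handled by `parse` above, here is one example string decoded layer by layer (activation defaults per `get_activation`: ReLU for convolutions, sigmoid for fully connected layers):

    # '16c3s2-32c3s2-ap2-f100-f3t' decodes to:
    #   16c3s2 -> convolution, 16 filters, 3x3 kernel, stride 2, ReLU (default)
    #   32c3s2 -> convolution, 32 filters, 3x3 kernel, stride 2, ReLU (default)
    #   ap2    -> 2x2 max-pooling with argmax (mask reused by unpool in the decoder)
    #   f100   -> fully connected, 100 units, sigmoid (default)
    #   f3t    -> fully connected, 3 units, tanh ('t' suffix)

The decoder is built by walking the same config in reverse, so every descriptor string implicitly defines both halves of the autoencoder.
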
/model_interpreter_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow.contrib.slim as slim
4 | import network_utils as nut
5 | import utils as ut
6 | import re
7 | import os
8 | from Bunch import Bunch
9 | from model_interpreter import *
10 |
11 |
12 | def _log_graph():
13 | path = '/tmp/interpreter'
14 | with tf.Session() as sess:
15 |     sess.run(tf.global_variables_initializer())
16 | tf.summary.FileWriter(path, sess.graph)
17 | ut.print_color(os.path.abspath(path), color=33)
18 |
19 |
20 | def _test_parameter_reuse_conv():
21 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
22 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-1c3-f4')
23 |
24 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
25 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-1c3-f4')
26 | l2_loss(input, model.decode)
27 |
28 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
29 | model = build_autoencoder(input, model.config)
30 | l2_loss(input, model.decode)
31 |
32 |
33 | def _test_parameter_reuse_decoder():
34 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
35 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-16c3-f4')
36 |
37 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
38 | model = build_autoencoder(input, '8c3s2-16c3s2-32c3s2-16c3-f4')
39 |
40 | input_enc = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input_encoder')
41 | encoder = build_encoder(input_enc, model.config, reuse=True)
42 |
43 | input_dec = tf.placeholder(tf.float32, (2, 4), name='input_decoder')
44 | decoder = build_decoder(input_dec, model.config, reuse=True)
45 |
46 | l2_loss(input, model.decode)
47 | l2_loss(encoder, model.encode)
48 | l2_loss(decoder, model.decode)
49 |
50 |
51 | def _test_argmax_ae():
52 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
53 | model = build_autoencoder(input, '8c3-ap2-16c3-ap2-32c3-ap2-16c3-f4')
54 |
55 | input = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input')
56 | model = build_autoencoder(input, '8c3-ap2-16c3-ap2-32c3-ap2-16c3-f4')
57 |
58 | input_enc = tf.placeholder(tf.float32, (2, 120, 120, 4), name='input_encoder')
59 | encoder = build_encoder(input_enc, model.config, reuse=True)
60 |
61 | input_dec = tf.placeholder(tf.float32, (2, 4), name='input_decoder')
62 | decoder = build_decoder(input_dec, model.config, reuse=True)
63 |
64 | l2_loss(input, model.decode)
65 | l2_loss(encoder, model.encode)
66 | l2_loss(decoder, model.decode)
67 |
68 |
69 | def _test_multiple_decoders_unpool_wiring():
70 | input = tf.placeholder(tf.float32, (2, 16, 16, 3), name='input')
71 | model1 = build_autoencoder(input, '8c3-ap2-f4')
72 | model2 = build_autoencoder(input, model1.config)
73 | enc_1 = tf.placeholder(tf.float32, (2, 4), name='enc_1')
74 | enc_2 = tf.placeholder(tf.float32, (2, 4), name='enc_2')
75 | decoder1 = build_decoder(enc_1, model1.config, reuse=True, masks=model1.mask_list)
76 | decoder2 = build_decoder(enc_2, model2.config, reuse=True, masks=model2.mask_list)
77 |
78 |
79 | def _visualize_models():
80 | input = tf.placeholder(tf.float32, (128, 160, 120, 4), name='input_fc')
81 | model = build_autoencoder(input, 'f100-f3')
82 |   loss = l2_loss(input, model.decode, name='Loss_reconstruction_FC')
83 |
84 | input = tf.placeholder(tf.float32, (128, 160, 120, 4), name='input_conv')
85 | model = build_autoencoder(input, '16c3s2-32c3s2-32c3s2-16c3-f3')
86 |   loss = l2_loss(input, model.decode, name='Loss_reconstruction_conv')
87 |
88 | input = tf.placeholder(tf.float32, (128, 160, 120, 4), name='input_wwae')
89 | model = build_autoencoder(input, '16c3-ap2-32c3-ap2-16c3-f3')
90 |   loss = l2_loss(input, model.decode, name='Loss_reconstruction_WWAE')
91 |
92 |
93 | def _test_stacked_ae():
94 | input = tf.placeholder(tf.float32, (2, 16, 16, 3), name='input')
95 | model1 = build_autoencoder(input, 'f500-f100-f5')
96 | losses = build_stacked_losses(model1)
97 |
98 |
99 | if __name__ == '__main__':
100 | # print(re.match('\d+c\d+(s\d+)?[r|s|i|t]?', '8c3s2'))
101 | # model = build_autoencoder(tf.placeholder(tf.float32, (2, 16, 16, 3), name='input'), '8c3s2-16c3s2-30c3s2-16c3-f4')
102 | # _test_multiple_decoders_unpool_wiring()
103 | # _visualize_models()
104 | x = _test_stacked_ae()
105 |
106 | # build_autoencoder(tf.placeholder(tf.float32, (2, 16, 16, 3), name='input'), '10c3-f100-f10')
107 | # _test_parameter_reuse_conv()
108 | # _test_parameter_reuse_decoder()
109 |   # _test_argmax_ae()
110 | _log_graph()
111 |
--------------------------------------------------------------------------------
/network_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow.contrib.slim as slim
4 | import tools.checkpoint_utils as ch_utils
5 | import scipy.stats as st
6 | import inspect
7 |
8 | # POOLING
9 |
10 |
11 | def max_pool_with_argmax(net, stride):
12 | """
13 |   TensorFlow's default implementation does not provide a gradient operation for max_pool_with_argmax.
14 |   Therefore, we use max_pool_with_argmax only to extract the mask and
15 |   plain max_pool for, eeem... max pooling.
16 | """
17 | with tf.name_scope('MaxPoolArgMax'):
18 | _, mask = tf.nn.max_pool_with_argmax(
19 | net,
20 | ksize=[1, stride, stride, 1],
21 | strides=[1, stride, stride, 1],
22 | padding='SAME')
23 | mask = tf.stop_gradient(mask)
24 | net = slim.max_pool2d(net, kernel_size=[stride, stride], stride=stride)
25 | return net, mask
26 |
27 |
28 | def fake_arg_max_of_max_pool(shape, stride=2):
29 | assert shape[1] % stride == 0 and shape[2] % stride == 0, \
30 | 'Smart padding is not supported. Indexes %s are not multiple of stride:%d' % (str(shape[1:3]), stride)
31 | mask = np.arange(np.prod(shape[1:]))
32 | mask = mask.reshape(shape[1:])
33 | mask = mask[::stride, ::stride, :]
34 | mask = np.tile(mask, (shape[0], 1, 1, 1))
35 | return mask
36 |
37 |
38 | # Thank you, @https://github.com/Pepslee
39 | def unpool(net, mask, stride=2):
40 | assert mask is not None
41 | with tf.name_scope('UnPool2D'):
42 | ksize = [1, stride, stride, 1]
43 | input_shape = net.get_shape().as_list()
44 |     # calculate the new output shape
45 | output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])
46 |     # calculate indices for batch, height, width and feature maps
47 | one_like_mask = tf.ones_like(mask)
48 | batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int64), shape=[input_shape[0], 1, 1, 1])
49 | b = one_like_mask * batch_range
50 | y = mask // (output_shape[2] * output_shape[3])
51 | x = mask % (output_shape[2] * output_shape[3]) // output_shape[3]
52 | feature_range = tf.range(output_shape[3], dtype=tf.int64)
53 | f = one_like_mask * feature_range
54 | # transpose indices & reshape update values to one dimension
55 | updates_size = tf.size(net)
56 | indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
57 | values = tf.reshape(net, [updates_size])
58 | ret = tf.scatter_nd(indices, values, output_shape)
59 | return ret
60 |
61 |
62 | def upsample(net, stride, mode='ZEROS'):
63 | """
64 |   Imitates the reverse operation of max-pooling by either placing the original max values
65 |   into a fixed position of the upsampled cell:
66 | [0.9] =>[[.9, 0], (stride=2)
67 | [ 0, 0]]
68 | or copying the value into each cell:
69 | [0.9] =>[[.9, .9], (stride=2)
70 | [ .9, .9]]
71 |
72 | :param net: 4D input tensor with [batch_size, width, heights, channels] axis
73 | :param stride:
74 | :param mode: string 'ZEROS' or 'COPY' indicating which value to use for undefined cells
75 | :return: 4D tensor of size [batch_size, width*stride, heights*stride, channels]
76 | """
77 | assert mode in ['COPY', 'ZEROS']
78 | with tf.name_scope('Upsampling'):
79 | net = _upsample_along_axis(net, 2, stride, mode=mode)
80 | net = _upsample_along_axis(net, 1, stride, mode=mode)
81 | return net
82 |
83 |
84 | def _upsample_along_axis(volume, axis, stride, mode='ZEROS'):
85 | shape = volume.get_shape().as_list()
86 |
87 | assert mode in ['COPY', 'ZEROS']
88 | assert 0 <= axis < len(shape)
89 |
90 | target_shape = shape[:]
91 | target_shape[axis] *= stride
92 |
93 | padding = tf.zeros(shape, dtype=volume.dtype) if mode == 'ZEROS' else volume
94 | parts = [volume] + [padding for _ in range(stride - 1)]
95 | assert list(inspect.signature(tf.concat).parameters.items())[1][0] == 'axis', 'Wrong TF version'
96 | volume = tf.concat(parts, min(axis+1, len(shape)-1))
97 | volume = tf.reshape(volume, target_shape)
98 | return volume
99 |
100 |
101 | # VARIABLES
102 |
103 |
104 | def print_model_info(trainable=False):
105 | if not trainable:
106 | for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
107 | print(v.name, v.get_shape())
108 | else:
109 | for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
110 | value = v.eval()
111 | print('TRAINABLE_VARIABLES', v.name, v.get_shape(), 'm:%.4f v:%.4f' % (value.mean(), value.std()))
112 |
113 |
114 | def list_checkpoint_vars(folder):
115 | # print(folder)
116 | f = ch_utils.list_variables(folder)
117 | # print(f)
118 | print('\n'.join(map(str, f)))
119 |
120 |
121 | def get_variable(checkpoint, name):
122 | var = ch_utils.load_variable(tf.train.latest_checkpoint(checkpoint), name)
123 | return var
124 |
125 |
126 | def scope_wrapper(scope_name):
127 | def scope_decorator(func):
128 | def func_wrapper(*args, **kwargs):
129 | with tf.name_scope(scope_name):
130 | return func(*args, **kwargs)
131 | return func_wrapper
132 | return scope_decorator
133 |
134 |
135 | # Gaussian blur
136 |
137 |
138 | def _build_gaussian_kernel(k_size, nsig, channels):
139 | interval = (2 * nsig + 1.) / k_size
140 | x = np.linspace(-nsig - interval / 2., nsig + interval / 2., k_size + 1)
141 | kern1d = np.diff(st.norm.cdf(x))
142 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
143 | kernel = kernel_raw / kernel_raw.sum()
144 | out_filter = np.array(kernel, dtype=np.float32)
145 | out_filter = out_filter.reshape((k_size, k_size, 1, 1))
146 | out_filter = np.repeat(out_filter, channels, axis=2)
147 | return out_filter
148 |
149 |
150 | def blur_gaussian(input, sigma, filter_size):
151 | num_channels = input.get_shape().as_list()[3]
152 | with tf.variable_scope('gaussian_filter'):
153 | kernel = _build_gaussian_kernel(filter_size, sigma, num_channels)
154 | kernel = tf.constant(kernel.flatten(), shape=kernel.shape, name='gauss_weight')
155 | output = tf.nn.depthwise_conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
156 | return output, kernel
157 |
158 |
159 | def nan_to_zero(tensor):
160 | return tf.where(tf.is_nan(tensor), tf.zeros_like(tensor), tensor)
--------------------------------------------------------------------------------
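
Note on network_utils.py: a minimal sketch tying the pooling helpers above together (TF1-style placeholders with a fixed batch size, since `unpool` needs static shapes; assumes the helpers above are imported or in scope):

    import tensorflow as tf
    x = tf.placeholder(tf.float32, [4, 16, 16, 8])
    pooled, mask = max_pool_with_argmax(x, stride=2)  # -> [4, 8, 8, 8] plus argmax mask
    restored = unpool(pooled, mask=mask, stride=2)    # -> [4, 16, 16, 8], max values back
                                                      #    at their original cells, zeros elsewhere
    copied = upsample(pooled, stride=2, mode='COPY')  # -> [4, 16, 16, 8], each value copied
                                                      #    into its whole 2x2 cell
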
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yselivonchyk/TensorFlow_DCIGN/ff8d85f3a7b7ca1e5c3f50ff003a1c09a70067cd/tools/__init__.py
--------------------------------------------------------------------------------
/tools/checkpoint_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tools to work with checkpoints."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import six
23 | import tensorflow as tf
24 | from tensorflow.python.ops import gen_io_ops
25 | from tensorflow.python.ops import state_ops
26 | from tensorflow.python.ops import variable_scope as vs
27 | from tensorflow.python.ops import variables
28 | from tensorflow.python.platform import gfile
29 | from tensorflow.python.platform import tf_logging as logging
30 | from tensorflow.python.training import saver
31 | from tensorflow.python.training import training as train
32 |
33 | __all__ = [
34 | "load_checkpoint",
35 | "load_variable",
36 | "list_variables",
37 | "init_from_checkpoint"]
38 |
39 |
40 | def _get_checkpoint_filename(filepattern):
41 | """Returns checkpoint filename given directory or specific filepattern."""
42 | import os
43 | if os.path.isdir(filepattern):
44 | print('nope', filepattern)
45 | res = tf.train.latest_checkpoint(filepattern)
46 | print('but here', res)
47 | return res
48 | return filepattern
49 |
50 |
51 | def load_checkpoint(filepattern):
52 | """Returns CheckpointReader for latest checkpoint.
53 |
54 | Args:
55 | filepattern: Directory with checkpoints file or path to checkpoint.
56 |
57 | Returns:
58 | `CheckpointReader` object.
59 |
60 | Raises:
61 | ValueError: if checkpoint_dir doesn't have 'checkpoint' file or checkpoints.
62 | """
63 | filename = _get_checkpoint_filename(filepattern)
64 | if filename is None:
65 | raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
66 | "given directory %s" % filepattern)
67 | return train.NewCheckpointReader(filename)
68 |
69 |
70 | def load_variable(checkpoint_dir, name):
71 | """Returns a Tensor with the contents of the given variable in the checkpoint.
72 |
73 | Args:
74 | checkpoint_dir: Directory with checkpoints file or path to checkpoint.
75 | name: Name of the tensor to return.
76 |
77 | Returns:
78 | `Tensor` object.
79 | """
80 | # TODO(b/29227106): Fix this in the right place and remove this.
81 | if name.endswith(":0"):
82 | name = name[:-2]
83 | reader = load_checkpoint(checkpoint_dir)
84 | return reader.get_tensor(name)
85 |
86 |
87 | def list_variables(checkpoint_dir):
88 | """Returns list of all variables in the latest checkpoint.
89 |
90 | Args:
91 | checkpoint_dir: Directory with checkpoints file or path to checkpoint.
92 |
93 | Returns:
94 | List of tuples `(name, shape)`.
95 | """
96 | reader = load_checkpoint(checkpoint_dir)
97 | variable_map = reader.get_variable_to_shape_map()
98 | names = sorted(variable_map.keys())
99 | result = []
100 | for name in names:
101 | result.append((name, variable_map[name]))
102 | return result
103 |
104 |
105 | # pylint: disable=protected-access
106 | # Currently variable_scope doesn't provide very good APIs to access
107 | # all variables under scope and retrieve and check existing scopes.
108 | # TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs.
109 |
110 |
111 | def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
112 | name="checkpoint_initializer"):
113 | """Sets variable initializer to assign op form value in checkpoint's tensor.
114 |
115 | Args:
116 | variable: `Variable` object.
117 | file_pattern: string, where to load checkpoints from.
118 | tensor_name: Name of the `Tensor` to load from checkpoint reader.
119 | slice_spec: Slice specification for loading partitioned variables.
120 | name: Name of the operation.
121 | """
122 | base_type = variable.dtype.base_dtype
123 | restore_op = gen_io_ops._restore_slice(
124 | file_pattern,
125 | tensor_name,
126 | slice_spec,
127 | base_type,
128 | preferred_shard=-1,
129 | name=name)
130 | variable._initializer_op = state_ops.assign(variable, restore_op)
131 |
132 |
133 | def _set_variable_or_list_initializer(variable_or_list, file_pattern,
134 | tensor_name):
135 | if isinstance(variable_or_list, (list, tuple)):
136 | # A set of slices.
137 | slice_name = None
138 | for v in variable_or_list:
139 | if slice_name is None:
140 | slice_name = v._save_slice_info.full_name
141 | elif slice_name != v._save_slice_info.full_name:
142 | raise ValueError("Slices must all be from the same tensor: %s != %s" %
143 | (slice_name, v._save_slice_info.full_name))
144 | _set_checkpoint_initializer(v, file_pattern, tensor_name,
145 | v._save_slice_info.spec)
146 | else:
147 | _set_checkpoint_initializer(variable_or_list, file_pattern, tensor_name, "")
148 |
149 |
150 | def init_from_checkpoint(checkpoint_dir, assignment_map):
151 | """Using assingment map initializes current variables with loaded tensors.
152 |
153 | Note: This overrides default initialization ops of specified variables and
154 | redefines dtype.
155 |
156 |   The assignment map supports the following syntax:
157 | `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
158 | current `scope_name` from `checkpoint_scope_name` with matching variable
159 | names.
160 | `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
161 |   will initialize `scope_name/variable_name` variable
162 | from `checkpoint_scope_name/some_other_variable`.
163 | `'scope_variable_name': variable` - will initialize given `tf.Variable`
164 | object with variable from the checkpoint.
165 | `'scope_variable_name': list(variable)` - will initialize list of
166 | partitioned variables with variable from the checkpoint.
167 | `'scope_name/': '/'` - will load all variables in current `scope_name` from
168 |   checkpoint's root (i.e. no scope).
169 |
170 | Supports loading into partitioned variables, which are represented as
171 |   '<variable>/part_<part #>'.
172 |
173 | Example:
174 | ```python
175 | # Create variables.
176 | with tf.variable_scope('test'):
177 | m = tf.get_variable('my_var')
178 | with tf.variable_scope('test2'):
179 | var2 = tf.get_variable('my_var')
180 | var3 = tf.get_variable(name="my1", shape=[100, 100],
181 | partitioner=lambda shape, dtype: [5, 1])
182 | ...
183 |   # Specify which variables to initialize from checkpoint.
184 | init_from_checkpoint(checkpoint_dir, {
185 | 'some_var': 'test/my_var',
186 | 'some_scope/': 'test2/'})
187 | ...
188 | # Or use `Variable` objects to identify what to initialize.
189 | init_from_checkpoint(checkpoint_dir, {
190 | 'some_scope/var2': var2,
191 | })
192 | # Initialize partitioned variables
193 | init_from_checkpoint(checkpoint_dir, {
194 | 'some_var_from_ckpt': 'part_var',
195 | })
196 | # Or specifying the list of `Variable` objects.
197 | init_from_checkpoint(checkpoint_dir, {
198 | 'some_var_from_ckpt': var3._get_variable_list(),
199 | })
200 | ...
201 | # Initialize variables as usual.
202 |   session.run(tf.initialize_all_variables())
203 | ```
204 |
205 | Args:
206 | checkpoint_dir: Directory with checkpoints file or path to checkpoint.
207 | assignment_map: Dict, where keys are names of the variables in the
208 | checkpoint and values are current variables or names of current variables
209 | (in default graph).
210 |
211 | Raises:
212 | tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
213 | ValueError: If missing variables in current graph.
214 | """
215 | filepattern = _get_checkpoint_filename(checkpoint_dir)
216 | reader = load_checkpoint(checkpoint_dir)
217 | variable_map = reader.get_variable_to_shape_map()
218 | for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
219 | var = None
220 | # Check if this is Variable object or list of Variable objects (in case of
221 | # partitioned variables).
222 | is_var = lambda x: isinstance(x, variables.Variable)
223 | if is_var(current_var_or_name) or (
224 | isinstance(current_var_or_name, list)
225 | and all(is_var(v) for v in current_var_or_name)):
226 | var = current_var_or_name
227 | else:
228 | var_scope = vs._get_default_variable_store()
229 | # Check if this variable is in var_store.
230 | var = var_scope._vars.get(current_var_or_name, None)
231 | # Also check if variable is partitioned as list.
232 | if var is None:
233 | if current_var_or_name + "/part_0" in var_scope._vars:
234 | var = []
235 | i = 0
236 | while current_var_or_name + "/part_%d" % i in var_scope._vars:
237 | var.append(var_scope._vars[current_var_or_name + "/part_%d" % i])
238 | i += 1
239 | if var is not None:
240 | # If 1 to 1 mapping was provided, find variable in the checkpoint.
241 | if tensor_name_in_ckpt not in variable_map:
242 | raise ValueError("Tensor %s is not found in %s checkpoint" % (
243 | tensor_name_in_ckpt, checkpoint_dir
244 | ))
245 | if is_var(var):
246 | # Additional at-call-time checks.
247 | if not var.get_shape().is_compatible_with(
248 | variable_map[tensor_name_in_ckpt]):
249 | raise ValueError(
250 | "Shape of variable %s (%s) doesn't match with shape of "
251 | "tensor %s (%s) from checkpoint reader." % (
252 | var.name, str(var.get_shape()),
253 | tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
254 | ))
255 | var_name = var.name
256 | else:
257 | var_name = ",".join([v.name for v in var])
258 | _set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt)
259 | logging.info("Initialize variable %s from checkpoint %s with %s" % (
260 | var_name, checkpoint_dir, tensor_name_in_ckpt
261 | ))
262 | else:
263 | scopes = ""
264 | # TODO(vihanjain): Support list of 'current_var_or_name' here.
265 | if "/" in current_var_or_name:
266 | scopes = current_var_or_name[:current_var_or_name.rindex("/")]
267 | if not tensor_name_in_ckpt.endswith("/"):
268 | raise ValueError(
269 | "Assignment map with scope only name {} should map to scope only "
270 | "{}. Should be 'scope/': 'other_scope/'.".format(
271 | scopes, tensor_name_in_ckpt))
272 | # If scope to scope mapping was provided, find all variables in the scope.
273 | for var_name in var_scope._vars:
274 | if var_name.startswith(scopes):
275 | # Lookup name with specified prefix and suffix from current variable.
276 | # If tensor_name given is '/' (root), don't use it for full name.
277 | if tensor_name_in_ckpt != "/":
278 | full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:]
279 | else:
280 | full_tensor_name = var_name[len(scopes) + 1:]
281 | if full_tensor_name not in variable_map:
282 | raise ValueError(
283 | "Tensor %s (%s in %s) is not found in %s checkpoint" % (
284 | full_tensor_name, var_name[len(scopes) + 1:],
285 | tensor_name_in_ckpt, checkpoint_dir
286 | ))
287 | var = var_scope._vars[var_name]
288 | _set_variable_or_list_initializer(var, filepattern, full_tensor_name)
289 | logging.info("Initialize variable %s from checkpoint %s with %s" % (
290 | var_name, checkpoint_dir, full_tensor_name
291 | ))
292 | # pylint: enable=protected-access
--------------------------------------------------------------------------------
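
A minimal usage sketch for the helpers above (TF 0.x-era API). The checkpoint directory, scope names, and variable names are assumptions for illustration, and the import path assumes this repo's `tools` package:

```python
import tensorflow as tf
from tools import checkpoint_utils  # assumed import path within this repo

checkpoint_dir = '/tmp/my_model'  # hypothetical dir with a 'checkpoint' file

# Enumerate every (name, shape) pair stored in the latest checkpoint.
for name, shape in checkpoint_utils.list_variables(checkpoint_dir):
  print(name, shape)

with tf.Graph().as_default():
  with tf.variable_scope('new_scope'):
    w = tf.get_variable('weights', shape=[10, 10])
  # Rewire w's initializer so it restores from 'old_scope/weights'.
  checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                        {'old_scope/weights': w})
  with tf.Session() as sess:
    # Running the initializers now pulls values from the checkpoint.
    sess.run(tf.initialize_all_variables())
```
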
/tools/freeze_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Converts checkpoint variables into Const ops in a standalone GraphDef file.
16 |
17 | This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
18 | variable values stored in a checkpoint file, and output a GraphDef with all of
19 | the variable ops converted into const ops containing the values of the
20 | variables.
21 |
22 | It's useful to do this when we need to load a single file in C++, especially in
23 | environments like mobile or embedded where we may not have access to the
24 | RestoreTensor ops and file loading calls that they rely on.
25 |
26 | An example of command-line usage is:
27 | bazel build tensorflow/python/tools:freeze_graph && \
28 | bazel-bin/tensorflow/python/tools/freeze_graph \
29 | --input_graph=some_graph_def.pb \
30 | --input_checkpoint=model.ckpt-8361242 \
31 | --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
32 |
33 | You can also look at freeze_graph_test.py for an example of how to use it.
34 |
35 | """
36 | from __future__ import absolute_import
37 | from __future__ import division
38 | from __future__ import print_function
39 |
40 | import tensorflow as tf
41 |
42 | from google.protobuf import text_format
43 | from tensorflow.python.framework import graph_util
44 |
45 |
46 | FLAGS = tf.app.flags.FLAGS
47 |
48 | tf.app.flags.DEFINE_string("input_graph", "",
49 | """TensorFlow 'GraphDef' file to load.""")
50 | tf.app.flags.DEFINE_string("input_saver", "",
51 | """TensorFlow saver file to load.""")
52 | tf.app.flags.DEFINE_string("input_checkpoint", "",
53 | """TensorFlow variables file to load.""")
54 | tf.app.flags.DEFINE_string("output_graph", "",
55 | """Output 'GraphDef' file name.""")
56 | tf.app.flags.DEFINE_boolean("input_binary", False,
57 | """Whether the input files are in binary format.""")
58 | tf.app.flags.DEFINE_string("output_node_names", "",
59 | """The name of the output nodes, comma separated.""")
60 | tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all",
61 | """The name of the master restore operator.""")
62 | tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
63 | """The name of the tensor holding the save path.""")
64 | tf.app.flags.DEFINE_boolean("clear_devices", True,
65 | """Whether to remove device specifications.""")
66 | tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
67 | "initializer nodes to run before freezing.")
68 |
69 |
70 | def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
71 | output_node_names, restore_op_name, filename_tensor_name,
72 | output_graph, clear_devices, initializer_nodes):
73 | """Converts all variables in a graph and checkpoint into constants."""
74 |
75 | if not tf.gfile.Exists(input_graph):
76 | print("Input graph file '" + input_graph + "' does not exist!")
77 | return -1
78 |
79 | if input_saver and not tf.gfile.Exists(input_saver):
80 | print("Input saver file '" + input_saver + "' does not exist!")
81 | return -1
82 |
83 | if not tf.gfile.Glob(input_checkpoint):
84 | print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
85 | return -1
86 |
87 | if not output_node_names:
88 | print("You need to supply the name of a node to --output_node_names.")
89 | return -1
90 |
91 | input_graph_def = tf.GraphDef()
92 | mode = "rb" if input_binary else "r"
93 | with tf.gfile.FastGFile(input_graph, mode) as f:
94 | if input_binary:
95 | input_graph_def.ParseFromString(f.read())
96 | else:
97 | text_format.Merge(f.read(), input_graph_def)
98 | # Remove all the explicit device specifications for this node. This helps to
99 | # make the graph more portable.
100 | if clear_devices:
101 | for node in input_graph_def.node:
102 | node.device = ""
103 | _ = tf.import_graph_def(input_graph_def, name="")
104 |
105 | with tf.Session() as sess:
106 | if input_saver:
107 | with tf.gfile.FastGFile(input_saver, mode) as f:
108 | saver_def = tf.train.SaverDef()
109 | if input_binary:
110 | saver_def.ParseFromString(f.read())
111 | else:
112 | text_format.Merge(f.read(), saver_def)
113 | saver = tf.train.Saver(saver_def=saver_def)
114 | saver.restore(sess, input_checkpoint)
115 | else:
116 | sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
117 | if initializer_nodes:
118 | sess.run(initializer_nodes)
119 | output_graph_def = graph_util.convert_variables_to_constants(
120 | sess, input_graph_def, output_node_names.split(","))
121 |
122 | with tf.gfile.GFile(output_graph, "wb") as f:
123 | f.write(output_graph_def.SerializeToString())
124 | print("%d ops in the final graph." % len(output_graph_def.node))
125 |
126 |
127 | def main(unused_args):
128 | freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
129 | FLAGS.input_checkpoint, FLAGS.output_node_names,
130 | FLAGS.restore_op_name, FLAGS.filename_tensor_name,
131 | FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
132 |
133 | if __name__ == "__main__":
134 | tf.app.run()
135 |
--------------------------------------------------------------------------------
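
For completeness, a hedged sketch of consuming the frozen output in Python. The file path, the `input:0`/`softmax:0` tensor names, and the input shape are assumptions; substitute whatever names were used when freezing:

```python
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.FastGFile('/tmp/frozen_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')
  with tf.Session(graph=graph) as sess:
    softmax = graph.get_tensor_by_name('softmax:0')
    batch = np.zeros([1, 224, 224, 3], np.float32)  # assumed input shape
    print(sess.run(softmax, feed_dict={'input:0': batch}))
```
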
/tools/freeze_graph_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests the graph freezing tool."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 |
22 | import tensorflow as tf
23 |
24 | from tensorflow.python.framework import test_util
25 | from tensorflow.python.tools import freeze_graph
26 |
27 |
28 | class FreezeGraphTest(test_util.TensorFlowTestCase):
29 |
30 | def testFreezeGraph(self):
31 |
32 | checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
33 | checkpoint_state_name = "checkpoint_state"
34 | input_graph_name = "input_graph.pb"
35 | output_graph_name = "output_graph.pb"
36 |
37 | # We'll create an input graph that has a single variable containing 1.0,
38 | # and that then multiplies it by 2.
39 | with tf.Graph().as_default():
40 | variable_node = tf.Variable(1.0, name="variable_node")
41 | output_node = tf.mul(variable_node, 2.0, name="output_node")
42 | sess = tf.Session()
43 | init = tf.initialize_all_variables()
44 | sess.run(init)
45 | output = sess.run(output_node)
46 | self.assertNear(2.0, output, 0.00001)
47 | saver = tf.train.Saver()
48 | saver.save(sess, checkpoint_prefix, global_step=0,
49 | latest_filename=checkpoint_state_name)
50 | tf.train.write_graph(sess.graph.as_graph_def(), self.get_temp_dir(),
51 | input_graph_name)
52 |
53 | # We save out the graph to disk, and then call the const conversion
54 | # routine.
55 | input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
56 | input_saver_def_path = ""
57 | input_binary = False
58 | input_checkpoint_path = checkpoint_prefix + "-0"
59 | output_node_names = "output_node"
60 | restore_op_name = "save/restore_all"
61 | filename_tensor_name = "save/Const:0"
62 | output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
63 | clear_devices = False
64 |
65 | freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
66 | input_binary, input_checkpoint_path,
67 | output_node_names, restore_op_name,
68 | filename_tensor_name, output_graph_path,
69 | clear_devices, "")
70 |
71 | # Now we make sure the variable is now a constant, and that the graph still
72 | # produces the expected result.
73 | with tf.Graph().as_default():
74 | output_graph_def = tf.GraphDef()
75 | with open(output_graph_path, "rb") as f:
76 | output_graph_def.ParseFromString(f.read())
77 | _ = tf.import_graph_def(output_graph_def, name="")
78 |
79 | self.assertEqual(4, len(output_graph_def.node))
80 | for node in output_graph_def.node:
81 | self.assertNotEqual("Variable", node.op)
82 |
83 | with tf.Session() as sess:
84 | output_node = sess.graph.get_tensor_by_name("output_node:0")
85 | output = sess.run(output_node)
86 | self.assertNear(2.0, output, 0.00001)
87 |
88 | if __name__ == "__main__":
89 | tf.test.main()
90 |
--------------------------------------------------------------------------------
/tools/graph_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Gives estimates of computation and parameter sizes for a GraphDef.
17 |
18 | This script takes a GraphDef representing a network, and produces rough
19 | estimates of the number of floating-point operations needed to implement it and
20 | how many parameters are stored. You need to pass in the input size, and the
21 | results are only approximate, since it only calculates them for a subset of
22 | common operations.
23 |
24 | If you have downloaded the Inception graph for the label_image example, an
25 | example of using this script would be:
26 |
27 | bazel-bin/third_party/tensorflow/python/tools/graph_metrics \
28 | --graph tensorflow_inception_graph.pb \
29 | --statistics=weight_parameters,flops
30 |
31 | """
32 | from __future__ import absolute_import
33 | from __future__ import division
34 | from __future__ import print_function
35 |
36 | import locale
37 |
38 | import tensorflow as tf
39 |
40 | from google.protobuf import text_format
41 |
42 | from tensorflow.core.framework import graph_pb2
43 | from tensorflow.python.framework import ops
44 |
45 |
46 | FLAGS = tf.flags.FLAGS
47 |
48 | tf.flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""")
49 | tf.flags.DEFINE_bool("input_binary", True,
50 | """Whether the input files are in binary format.""")
51 | tf.flags.DEFINE_string("input_layer", "Mul:0",
52 | """The name of the input node.""")
53 | tf.flags.DEFINE_integer("batch_size", 1,
54 | """The batch size to use for the calculations.""")
55 | tf.flags.DEFINE_string("statistics", "weight_parameters,flops",
56 | """Which statistic types to examine.""")
57 | tf.flags.DEFINE_string("input_shape_override", "",
58 | """If this is set, the comma-separated values will be"""
59 | """ used to set the shape of the input layer.""")
60 | tf.flags.DEFINE_boolean("print_nodes", False,
61 | """Whether to show statistics for each op.""")
62 |
63 |
64 | def print_stat(prefix, statistic_type, value):
65 | if value is None:
66 | friendly_value = "None"
67 | else:
68 | friendly_value = locale.format("%d", value, grouping=True)
69 | print("%s%s=%s" % (prefix, statistic_type, friendly_value))
70 |
71 |
72 | def main(unused_args):
73 | if not tf.gfile.Exists(FLAGS.graph):
74 | print("Input graph file '" + FLAGS.graph + "' does not exist!")
75 | return -1
76 | graph_def = graph_pb2.GraphDef()
77 | with open(FLAGS.graph, "rb") as f:
78 | if FLAGS.input_binary:
79 | graph_def.ParseFromString(f.read())
80 | else:
81 | text_format.Merge(f.read(), graph_def)
82 | statistic_types = FLAGS.statistics.split(",")
83 | if FLAGS.input_shape_override:
84 |     input_shape_override = list(map(int, FLAGS.input_shape_override.split(",")))
85 | else:
86 | input_shape_override = None
87 | total_stats, node_stats = calculate_graph_metrics(
88 | graph_def, statistic_types, FLAGS.input_layer, input_shape_override,
89 | FLAGS.batch_size)
90 | if FLAGS.print_nodes:
91 | for node in graph_def.node:
92 | for statistic_type in statistic_types:
93 | current_stats = node_stats[statistic_type][node.name]
94 | print_stat(node.name + "(" + node.op + "): ", statistic_type,
95 | current_stats.value)
96 | for statistic_type in statistic_types:
97 | value = total_stats[statistic_type].value
98 | print_stat("Total: ", statistic_type, value)
99 |
100 |
101 | def calculate_graph_metrics(graph_def, statistic_types, input_layer,
102 | input_shape_override, batch_size):
103 | """Looks at the performance statistics of all nodes in the graph."""
104 | _ = tf.import_graph_def(graph_def, name="")
105 | total_stats = {}
106 | node_stats = {}
107 | for statistic_type in statistic_types:
108 | total_stats[statistic_type] = ops.OpStats(statistic_type)
109 | node_stats[statistic_type] = {}
110 | # Make sure we get pretty-printed numbers with separators.
111 | locale.setlocale(locale.LC_ALL, "")
112 | with tf.Session() as sess:
113 | input_tensor = sess.graph.get_tensor_by_name(input_layer)
114 | input_shape_tensor = input_tensor.get_shape()
115 | if input_shape_tensor:
116 | input_shape = input_shape_tensor.as_list()
117 | else:
118 | input_shape = None
119 | if input_shape_override:
120 | input_shape = input_shape_override
121 | if input_shape is None:
122 | raise ValueError("""No input shape was provided on the command line,"""
123 | """ and the input op itself had no default shape, so"""
124 | """ shape inference couldn't be performed. This is"""
125 | """ required for metrics calculations.""")
126 | input_shape[0] = batch_size
127 | input_tensor.set_shape(input_shape)
128 | for node in graph_def.node:
129 | # Ensure that the updated input shape has been fully-propagated before we
130 | # ask for the statistics, since they may depend on the output size.
131 | op = sess.graph.get_operation_by_name(node.name)
132 | ops.set_shapes_for_outputs(op)
133 | for statistic_type in statistic_types:
134 | current_stats = ops.get_stats_for_node_def(sess.graph, node,
135 | statistic_type)
136 | node_stats[statistic_type][node.name] = current_stats
137 | total_stats[statistic_type] += current_stats
138 | return total_stats, node_stats
139 |
140 | if __name__ == "__main__":
141 | tf.app.run()
142 |
--------------------------------------------------------------------------------
/tools/graph_metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests the graph metrics tool."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from tensorflow.python.tools import graph_metrics
23 |
24 |
25 | class GraphMetricsTest(tf.test.TestCase):
26 |
27 | def testGraphMetrics(self):
28 | with tf.Graph().as_default():
29 | input_node = tf.placeholder(tf.float32, shape=[10, 20], name="input_node")
30 | weights_node = tf.constant(0.0,
31 | dtype=tf.float32,
32 | shape=[20, 5],
33 | name="weights_node")
34 | tf.matmul(input_node, weights_node, name="matmul_node")
35 | sess = tf.Session()
36 | graph_def = sess.graph.as_graph_def()
37 | statistic_types = ["weight_parameters", "flops"]
38 | total_stats, node_stats = graph_metrics.calculate_graph_metrics(
39 | graph_def, statistic_types, "input_node:0", None, 10)
40 | expected = {"weight_parameters": 100, "flops": 2000}
41 | for statistic_type in statistic_types:
42 | current_stats = node_stats[statistic_type]["matmul_node"]
43 | self.assertEqual(expected[statistic_type], current_stats.value)
44 | for statistic_type in statistic_types:
45 | current_stats = total_stats[statistic_type]
46 | self.assertEqual(expected[statistic_type], current_stats.value)
47 |
48 |
49 | if __name__ == "__main__":
50 | tf.test.main()
51 |
--------------------------------------------------------------------------------
/tools/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A simple script for inspect checkpoint files."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import sys
22 |
23 | import tensorflow as tf
24 |
25 | FLAGS = tf.app.flags.FLAGS
26 | tf.app.flags.DEFINE_string("file_name", './../tmp/doom_bs__act|sigmoid__bs|20__h|500|5|500__init|na__inp|cbd4__lr|0.0004__opt|AO/-3560', "Checkpoint filename")
27 | tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
28 |
29 |
30 | def print_tensors_in_checkpoint_file(file_name, tensor_name):
31 | """Prints tensors in a checkpoint file.
32 |
33 | If no `tensor_name` is provided, prints the tensor names and shapes
34 | in the checkpoint file.
35 |
36 | If `tensor_name` is provided, prints the content of the tensor.
37 |
38 | Args:
39 | file_name: Name of the checkpoint file.
40 | tensor_name: Name of the tensor in the checkpoint file to print.
41 | """
42 | try:
43 | reader = tf.train.NewCheckpointReader(file_name)
44 | if not tensor_name:
45 | print(reader.debug_string().decode("utf-8"))
46 | else:
47 | print("tensor_name: ", tensor_name)
48 | print(reader.get_tensor(tensor_name))
49 | except Exception as e: # pylint: disable=broad-except
50 | print(str(e))
51 | if "corrupted compressed block contents" in str(e):
52 | print("It's likely that your checkpoint file has been compressed "
53 | "with SNAPPY.")
54 |
55 |
56 | def main(unused_argv):
57 | if not FLAGS.file_name:
58 | print(FLAGS.file_name)
59 | print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
60 | "[--tensor_name=tensor_to_print]")
61 | sys.exit(1)
62 | else:
63 | print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)
64 |
65 | if __name__ == "__main__":
66 |   tf.app.run()
68 |
--------------------------------------------------------------------------------
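
A hedged usage sketch for the inspector above; the checkpoint path and tensor name are placeholders, and the import path assumes this repo's `tools` package:

```python
from tools import inspect_checkpoint  # assumed import path

# List every tensor name and shape stored in a checkpoint.
inspect_checkpoint.print_tensors_in_checkpoint_file('/tmp/model.ckpt-1000', '')

# Print the contents of a single tensor.
inspect_checkpoint.print_tensors_in_checkpoint_file(
    '/tmp/model.ckpt-1000', 'softmax/weights')
```
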
/tools/strip_unused.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Removes unneeded nodes from a GraphDef file.
16 |
17 | This script is designed to help streamline models, by taking the input and
18 | output nodes that will be used by an application and figuring out the smallest
19 | set of operations that are required to run for those arguments. The resulting
20 | minimal graph is then saved out.
21 |
22 | The advantages of running this script are:
23 | - You may be able to shrink the file size.
24 | - Operations that are unsupported on your platform but still present can be
25 | safely removed.
26 | The resulting graph may not be as flexible as the original though, since any
27 | input nodes that weren't explicitly mentioned may not be accessible any more.
28 |
29 | An example of command-line usage is:
30 | bazel build tensorflow/python/tools:strip_unused && \
31 | bazel-bin/tensorflow/python/tools/strip_unused \
32 | --input_graph=some_graph_def.pb \
33 | --output_graph=/tmp/stripped_graph.pb \
34 | --input_node_names=input0 \
35 | --output_node_names=softmax
36 |
37 | You can also look at strip_unused_test.py for an example of how to use it.
38 |
39 | """
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 | import copy
44 |
45 | import tensorflow as tf
46 |
47 | from google.protobuf import text_format
48 | from tensorflow.python.framework import graph_util
49 |
50 |
51 | FLAGS = tf.app.flags.FLAGS
52 |
53 | tf.app.flags.DEFINE_string("input_graph", "",
54 | """TensorFlow 'GraphDef' file to load.""")
55 | tf.app.flags.DEFINE_boolean("input_binary", False,
56 | """Whether the input files are in binary format.""")
57 | tf.app.flags.DEFINE_string("output_graph", "",
58 | """Output 'GraphDef' file name.""")
59 | tf.app.flags.DEFINE_string("input_node_names", "",
60 | """The name of the input nodes, comma separated.""")
61 | tf.app.flags.DEFINE_string("output_node_names", "",
62 | """The name of the output nodes, comma separated.""")
63 | tf.app.flags.DEFINE_integer("placeholder_type_enum",
64 | tf.float32.as_datatype_enum,
65 | """The AttrValue enum to use for placeholders.""")
66 |
67 |
68 | def strip_unused(input_graph, input_binary, output_graph, input_node_names,
69 | output_node_names, placeholder_type_enum):
70 | """Removes unused nodes from a graph."""
71 |
72 | if not tf.gfile.Exists(input_graph):
73 | print("Input graph file '" + input_graph + "' does not exist!")
74 | return -1
75 |
76 | if not output_node_names:
77 | print("You need to supply the name of a node to --output_node_names.")
78 | return -1
79 |
80 | input_graph_def = tf.GraphDef()
81 | mode = "rb" if input_binary else "r"
82 | with tf.gfile.FastGFile(input_graph, mode) as f:
83 | if input_binary:
84 | input_graph_def.ParseFromString(f.read())
85 | else:
86 | text_format.Merge(f.read(), input_graph_def)
87 |
88 | # Here we replace the nodes we're going to override as inputs with
89 | # placeholders so that any unused nodes that are inputs to them are
90 | # automatically stripped out by extract_sub_graph().
91 | input_node_names_list = input_node_names.split(",")
92 | inputs_replaced_graph_def = tf.GraphDef()
93 | for node in input_graph_def.node:
94 | if node.name in input_node_names_list:
95 | placeholder_node = tf.NodeDef()
96 | placeholder_node.op = "Placeholder"
97 | placeholder_node.name = node.name
98 | placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
99 | type=placeholder_type_enum))
100 | inputs_replaced_graph_def.node.extend([placeholder_node])
101 | else:
102 | inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
103 |
104 | output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
105 | output_node_names.split(","))
106 |
107 | with tf.gfile.GFile(output_graph, "wb") as f:
108 | f.write(output_graph_def.SerializeToString())
109 | print("%d ops in the final graph." % len(output_graph_def.node))
110 |
111 |
112 | def main(unused_args):
113 | strip_unused(FLAGS.input_graph, FLAGS.input_binary, FLAGS.output_graph,
114 | FLAGS.input_node_names, FLAGS.output_node_names,
115 | FLAGS.placeholder_type_enum)
116 |
117 | if __name__ == "__main__":
118 | tf.app.run()
119 |
--------------------------------------------------------------------------------
/tools/strip_unused_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests the graph freezing tool."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 |
22 | import tensorflow as tf
23 |
24 | from tensorflow.python.framework import test_util
25 | from tensorflow.python.tools import strip_unused
26 |
27 |
28 | class StripUnusedTest(test_util.TensorFlowTestCase):
29 |
30 |   def testStripUnused(self):
31 | input_graph_name = "input_graph.pb"
32 | output_graph_name = "output_graph.pb"
33 |
34 |     # We'll create an input graph that subtracts 3.0 from a constant 1.0,
35 |     # multiplies the result by 2, and adds an extra node to be stripped.
36 | with tf.Graph().as_default():
37 | constant_node = tf.constant(1.0, name="constant_node")
38 | wanted_input_node = tf.sub(constant_node, 3.0, name="wanted_input_node")
39 | output_node = tf.mul(wanted_input_node, 2.0, name="output_node")
40 | tf.add(output_node, 2.0, name="later_node")
41 | sess = tf.Session()
42 | output = sess.run(output_node)
43 | self.assertNear(-4.0, output, 0.00001)
44 | tf.train.write_graph(sess.graph.as_graph_def(), self.get_temp_dir(),
45 | input_graph_name)
46 |
47 |     # We save out the graph to disk, and then call the stripping
48 |     # routine.
49 | input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
50 | input_binary = False
51 | input_node_names = "wanted_input_node"
52 | output_node_names = "output_node"
53 | output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
54 |
55 | strip_unused.strip_unused(input_graph_path, input_binary, output_graph_path,
56 | input_node_names, output_node_names,
57 | tf.float32.as_datatype_enum)
58 |
59 |     # Now we make sure the unused nodes have been stripped, and that the
60 |     # graph still produces the expected result.
61 | with tf.Graph().as_default():
62 | output_graph_def = tf.GraphDef()
63 | with open(output_graph_path, "rb") as f:
64 | output_graph_def.ParseFromString(f.read())
65 | _ = tf.import_graph_def(output_graph_def, name="")
66 |
67 | self.assertEqual(3, len(output_graph_def.node))
68 | for node in output_graph_def.node:
69 | self.assertNotEqual("Add", node.op)
70 | self.assertNotEqual("Sub", node.op)
71 |
72 | with tf.Session() as sess:
73 | input_node = sess.graph.get_tensor_by_name("wanted_input_node:0")
74 | output_node = sess.graph.get_tensor_by_name("output_node:0")
75 | output = sess.run(output_node, feed_dict={input_node: [10.0]})
76 | self.assertNear(20.0, output, 0.00001)
77 |
78 | if __name__ == "__main__":
79 | tf.test.main()
80 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 | from matplotlib import pyplot as plt
4 | import os
5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
6 | import collections
7 | import tensorflow as tf
8 | from scipy import misc
9 | import re
10 | import sys
11 | import subprocess as sp
12 | import warnings
13 | import functools
14 | import scipy
15 | import random
16 | import numpy as np
16 |
17 |
18 |
19 | IMAGE_FOLDER = './img/'
20 | TEMP_FOLDER = './tmp/'
21 | EPOCH_THRESHOLD = 4
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | _start_time = None
25 |
26 |
27 | # CONSOLE OPERATIONS
28 |
29 |
30 | def reset_start_time():
31 | global _start_time
32 | _start_time = None
33 |
34 |
35 | def _get_time_offset():
36 | global _start_time
37 | time = datetime.datetime.now()
38 | if _start_time is None:
39 | _start_time = time
40 | return '\t\t'
41 | sec = (time - _start_time).total_seconds()
42 | res = '(+%d)\t' % sec if sec < 60 else '(+%d:%02d)\t' % (sec/60, sec%60)
43 | return res
44 |
45 |
46 | def print_time(*args, same_line=False):
47 | string = ''
48 | for a in args:
49 | string += str(a) + ' '
50 | time = datetime.datetime.now().time().strftime('%H:%M:%S')
51 | offset = _get_time_offset()
52 | res = '%s%s %s' % (str(time), offset, str(string))
53 | print_color(res, same_line=same_line)
54 |
55 |
56 | def print_info(string, color=32, same_line=False):
57 | print_color('\t' + str(string), color=color, same_line=same_line)
58 |
59 |
60 | same_line_prev = None
61 |
62 |
63 | def print_color(string, color=33, same_line=False):
64 | global same_line_prev
65 |   res = '%c[1;%dm%s%c[0m' % (27, color, str(string), 27)  # ANSI escape: ESC[1;<color>m ... ESC[0m
66 | if same_line:
67 |     # Blank out the previous line before overwriting it in place.
68 |     print('\r' + ' ' * 80, end=' ')
69 | print('\r' + res, end=' ')
70 | else:
71 | # if same_line_prev:
72 | # print('\n')
73 | print(res)
74 | same_line_prev = same_line
75 |
76 |
77 | def mnist_select_n_classes(train_images, train_labels, num_classes, min=None, scale=1.0):
78 | result_images, result_labels = [], []
79 | for i, j in zip(train_images, train_labels):
80 | if np.sum(j[0:num_classes]) > 0:
81 | result_images.append(i)
82 | result_labels.append(j[0:num_classes])
83 | inputs = np.asarray(result_images)
84 |
85 | inputs *= scale
86 | if min is not None:
87 | inputs = inputs - np.min(inputs) + min
88 | return inputs, np.asarray(result_labels)
89 |
90 |
91 | # IMAGE OPERATIONS
92 |
93 |
94 | def _save_image(name='image', save_params=None, image=None):
95 | if save_params is not None and 'e' in save_params and save_params['e'] < EPOCH_THRESHOLD:
96 | print_info('IMAGE: output is not saved. epochs %d < %d' % (save_params['e'], EPOCH_THRESHOLD), color=31)
97 | return
98 |
99 | file_name = name if save_params is None else to_file_name(save_params)
100 | file_name += '.png'
101 | name = os.path.join(FLAGS.save_path, file_name)
102 |
103 | if image is not None:
104 | misc.imsave(name, arr=image, format='png')
105 |
106 |
107 | def _show_picture(pic):
108 | fig = plt.figure()
109 | size = fig.get_size_inches()
110 | fig.set_size_inches(size[0], size[1] * 2, forward=True)
111 | plt.imshow(pic, cmap='Greys_r')
112 |
113 |
114 | def concat_images(im1, im2, axis=0):
115 | if im1 is None:
116 | return im2
117 | return np.concatenate((im1, im2), axis=axis)
118 |
119 |
120 | def _reconstruct_picture_line(pictures, shape):
121 | line_picture = None
122 | for _, img in enumerate(pictures):
123 | if len(img.shape) == 1:
124 | img = (np.reshape(img, shape))
125 | if len(img.shape) == 3 and img.shape[2] == 1:
126 | img = (np.reshape(img, (img.shape[0], img.shape[1])))
127 | line_picture = concat_images(line_picture, img)
128 | return line_picture
129 |
130 |
131 | def show_plt():
132 | plt.show()
133 |
134 |
135 | def _construct_img_shape(img):
136 | assert int(np.sqrt(img.shape[0])) == np.sqrt(img.shape[0])
137 | return int(np.sqrt(img.shape[0])), int(np.sqrt(img.shape[0])), 1
138 |
139 |
140 | def images_to_uint8(func):
141 | def normalize(arr):
142 | # if type(arr) == np.ndarray and arr.dtype != np.uint8 and len(arr.shape) >= 3:
143 | if type(arr) == np.ndarray and len(arr.shape) >= 3:
144 | if np.min(arr) < 0:
145 | print('image array normalization: negative values')
146 | if np.max(arr) < 10:
147 | arr *= 255
148 |       if arr.shape[-1] == 4 or arr.shape[-1] == 2:
149 |         # Drop the trailing (alpha/extra) channel before casting to uint8.
150 |         arr = arr[..., :arr.shape[-1]-1]
151 | return arr.astype(np.uint8)
152 | return arr
153 |
154 | def func_wrapper(*args, **kwargs):
155 | new_args = [normalize(el) for el in args]
156 |     new_kwargs = {k: normalize(v) for k, v in kwargs.items()}
157 | return func(*tuple(new_args), **new_kwargs)
158 | return func_wrapper
159 |
160 |
161 | def fig2buf(fig):
162 | fig.canvas.draw()
163 | return fig.canvas.tostring_rgb()
164 |
165 |
166 | def fig2rgb_array(fig, expand=True):
167 | fig.canvas.draw()
168 | buf = fig.canvas.tostring_rgb()
169 | ncols, nrows = fig.canvas.get_width_height()
170 | shape = (nrows, ncols, 3) if not expand else (1, nrows, ncols, 3)
171 | return np.fromstring(buf, dtype=np.uint8).reshape(shape)
172 |
173 |
174 | @images_to_uint8
175 | def reconstruct_images_epochs(epochs, original=None, save_params=None, img_shape=None):
176 | full_picture = None
177 | img_shape = img_shape if img_shape is not None else _construct_img_shape(epochs[0][0])
178 |
179 | # print(original.dtype, epochs.dtype, np.max(original), np.max(epochs))
180 |
181 |   if original is not None and original.dtype != np.uint8:
182 |     original = (original * 255).astype(np.uint8)
183 |   if epochs is not None and epochs.dtype != np.uint8:
184 |     epochs = (epochs * 255).astype(np.uint8)
185 |
186 | # print('image reconstruction: ', original.dtype, epochs.dtype, np.max(original), np.max(epochs))
187 |
188 | if original is not None and epochs is not None and len(epochs) >= 3:
189 | min_ref, max_ref = np.min(original), np.max(original)
190 | print_info('epoch avg: (original: %s) -> %s' % (
191 | str(np.mean(original)), str((np.mean(epochs[0]), np.mean(epochs[1]), np.mean(epochs[2])))))
192 | print_info('reconstruction char. in epochs (min, max)|original: (%f %f)|(%f %f)' % (
193 | np.min(epochs[1:]), np.max(epochs), min_ref, max_ref))
194 |
195 | if epochs is not None:
196 | for _, epoch in enumerate(epochs):
197 | full_picture = concat_images(full_picture, _reconstruct_picture_line(epoch, img_shape), axis=1)
198 | if original is not None:
199 | full_picture = concat_images(full_picture, _reconstruct_picture_line(original, img_shape), axis=1)
200 | _show_picture(full_picture)
201 | _save_image(save_params=save_params, image=full_picture)
202 |
203 |
204 | def model_to_file_name(FLAGS, folder=None, ext=None):
205 | postfix = '' if len(FLAGS.postfix) == 0 else '_%s' % FLAGS.postfix
206 | name = '%s.%s__i_%s%s' % (FLAGS.model, FLAGS.net.replace('-', '_'), FLAGS.input_name, postfix)
207 | if ext:
208 | name += '.' + ext
209 | if folder:
210 | name = os.path.join(folder, name)
211 | return name
212 |
213 |
214 | def mkdir(folders):
215 | if isinstance(folders, str):
216 | folders = [folders]
217 | for _, folder in enumerate(folders):
218 | if not os.path.exists(folder):
219 | os.mkdir(folder)
220 |
221 |
222 | def configure_folders(FLAGS):
223 | folder_name = model_to_file_name(FLAGS) + '/'
224 | FLAGS.save_path = os.path.join(TEMP_FOLDER, folder_name)
225 | FLAGS.logdir = FLAGS.save_path
226 | print_color(os.path.abspath(FLAGS.logdir))
227 | mkdir([TEMP_FOLDER, IMAGE_FOLDER, FLAGS.save_path, FLAGS.logdir])
228 |
229 | with open(os.path.join(FLAGS.save_path, '!note.txt'), "a") as f:
230 | f.write('\n' + ' '.join(sys.argv) + '\n')
231 | f.write(print_flags(FLAGS, print=False))
232 | if len(FLAGS.comment) > 0:
233 | f.write('\n\n%s\n' % FLAGS.comment)
234 |
235 |
236 | def get_files(folder="./visualizations/", filter=None):
237 | all = []
238 | for root, dirs, files in os.walk(folder):
239 | # print(root, dirs, files)
240 | if filter:
241 | files = [x for x in files if re.match(filter, x)]
242 | all += files
243 | return [os.path.join(folder, x) for x in all]
244 |
245 |
246 | def get_latest_file(folder="./visualizations/", filter=None):
247 | latest_file, latest_mod_time = None, None
248 | for root, dirs, files in os.walk(folder):
249 | # print(root, dirs, files)
250 | if filter:
251 | files = [x for x in files if re.match(filter, x)]
252 | # print('\n\r'.join(files))
253 | for file in files:
254 | file_path = os.path.join(root, file)
255 | modification_time = os.path.getmtime(file_path)
256 | if not latest_mod_time or modification_time > latest_mod_time:
257 | latest_mod_time = modification_time
258 | latest_file = file_path
259 | if latest_file is None:
260 | print_info('Could not find file matching %s' % str(filter))
261 | return latest_file
262 |
263 |
264 | def concatenate(x, y, take=None):
265 | """
266 | Stitches two np arrays together until maximum length, when specified
267 | """
268 | if take is not None and x is not None and len(x) >= take:
269 | return x
270 | if x is None:
271 | res = y
272 | else:
273 | res = np.concatenate((x, y))
274 | return res[:take] if take is not None else res
275 |
276 |
277 | # MISC
278 |
279 |
280 | def list_object_attributes(obj):
281 | print('Object type: %s\t\tattributes:' % str(type(obj)))
282 | print('\n\t'.join(map(str, obj.__dict__.keys())))
283 |
284 |
285 | def print_list(list):
286 | print('\n'.join(map(str, list)))
287 |
288 |
289 | def print_float_list(list, format='%.4f'):
290 | return ' '.join(map(lambda x: format%x, list))
291 |
292 |
293 | def timeit(method):
294 | def timed(*args, **kw):
295 | ts = time.time()
296 | result = method(*args, **kw)
297 | te = time.time()
298 | print('%r %2.2f sec' % (method.__name__, te-ts))
299 | return result
300 | return timed
301 |
302 |
303 | def deprecated(func):
304 | '''This is a decorator which can be used to mark functions
305 | as deprecated. It will result in a warning being emitted
306 | when the function is used.'''
307 |
308 | @functools.wraps(func)
309 | def new_func(*args, **kwargs):
310 | warnings.warn_explicit(
311 | "Call to deprecated function {}.".format(func.__name__),
312 | category=DeprecationWarning,
313 |       filename=func.__code__.co_filename,
314 |       lineno=func.__code__.co_firstlineno + 1
315 | )
316 | return func(*args, **kwargs)
317 |
318 | return new_func
319 |
320 |
322 | ACCEPTABLE_AVAILABLE_MEMORY = 1024
323 |
324 |
325 | def mask_busy_gpus(leave_unmasked=1, random=True):
326 | try:
327 | command = "nvidia-smi --query-gpu=memory.free --format=csv"
328 | memory_free_info = _output_to_list(sp.check_output(command.split()))[1:]
329 | memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
330 | available_gpus = [i for i, x in enumerate(memory_free_values) if x > ACCEPTABLE_AVAILABLE_MEMORY]
331 |
332 | if len(available_gpus) < leave_unmasked:
333 | print('Found only %d usable GPUs in the system' % len(available_gpus))
334 | exit(0)
335 |
336 | if random:
337 | available_gpus = np.asarray(available_gpus)
338 | np.random.shuffle(available_gpus)
339 |
340 | # update CUDA variable
341 | gpus = available_gpus[:leave_unmasked]
342 | setting = ','.join(map(str, gpus))
343 | os.environ["CUDA_VISIBLE_DEVICES"] = setting
344 | print('Left next %d GPU(s) unmasked: [%s] (from %s available)'
345 | % (leave_unmasked, setting, str(available_gpus)))
346 | except FileNotFoundError as e:
347 | print('"nvidia-smi" is probably not installed. GPUs are not masked')
348 | print(e)
349 | except sp.CalledProcessError as e:
350 | print("Error on GPU masking:\n", e.output)
351 |
352 |
353 | def _output_to_list(output):
354 | return output.decode('ascii').split('\n')[:-1]
355 |
356 |
357 | def get_gpu_free_session(memory_fraction=0.1):
358 | import tensorflow as tf
359 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=memory_fraction)
360 | return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
361 |
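# Illustrative usage (an assumption, not executed by this module): claim one
# free GPU, then open a session that grabs only a fraction of its memory.
#
#   mask_busy_gpus(leave_unmasked=1)   # sets CUDA_VISIBLE_DEVICES
#   sess = get_gpu_free_session(memory_fraction=0.3)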
362 |
363 | def parse_params():
364 | params = {}
365 | for i, param in enumerate(sys.argv):
366 | if '-' in param:
367 | params[param[1:]] = sys.argv[i+1]
368 | print(params)
369 | return params
370 |
371 |
372 | def print_flags(FLAGS, print=True):
373 |   x = FLAGS.input_path  # touching a flag forces tf.app.flags to parse argv
374 | res = 'FLAGS:'
375 | for i in sorted(FLAGS.__dict__['__flags'].items()):
376 | item = str(i)[2:-1].split('\', ')
377 | res += '\n%20s \t%s' % (item[0] + ':', item[1])
378 | if print:
379 | print_info(res)
380 | return res
381 |
382 |
383 | def _abbreviate_string(value):
384 | str_value = str(value)
385 | abbr = [letter for letter in str_value if letter.isupper()]
386 | if len(abbr) > 1:
387 | return ''.join(abbr)
388 |
389 | if len(str_value.split('_')) > 2:
390 | parts = str_value.split('_')
391 | letters = ''.join(x[0] for x in parts)
392 | return letters
393 | return value
394 |
395 |
396 | def to_file_name(obj, folder=None, ext=None, append_timestamp=False):
397 | name, postfix = '', ''
398 | od = collections.OrderedDict(sorted(obj.items()))
399 | for _, key in enumerate(od):
400 | value = obj[key]
401 | if value is None:
402 | value = 'na'
403 | #FUNC and OBJECTS
404 | if 'function' in str(value):
405 | value = str(value).split()[1].split('.')[0]
406 | parts = value.split('_')
407 | if len(parts) > 1:
408 | value = ''.join(list(map(lambda x: x.upper()[0], parts)))
409 | elif ' at ' in str(value):
410 | value = (str(value).split()[0]).split('.')[-1]
411 | value = _abbreviate_string(value)
412 | elif isinstance(value, type):
413 | value = _abbreviate_string(value.__name__)
414 | # FLOATS
415 | if isinstance(value, float) or isinstance(value, np.float32):
416 | if value < 0.0001:
417 | value = '%.6f' % value
418 | elif value > 1000000:
419 | value = '%.0f' % value
420 | else:
421 | value = '%.4f' % value
422 | value = value.rstrip('0')
423 | #INTS
424 | if isinstance(value, int):
425 | value = '%02d' % value
426 | #LIST
427 | if isinstance(value, list):
428 | value = '|'.join(map(str, value))
429 |
430 | truncate_threshold = 20
431 | value = _abbreviate_string(value)
432 | if len(value) > truncate_threshold:
433 | print_info('truncating this: %s %s' % (key, value))
434 | value = value[0:20]
435 |
436 | if 'suf' in key or 'postf' in key:
437 | continue
438 |
439 | name += '__%s|%s' % (key, str(value))
440 |
441 | if 'suf' in obj:
442 | prefix_value = obj['suf']
443 | else:
444 | prefix_value = FLAGS.suffix
445 | if 'postf' in obj:
446 | prefix_value += '_%s' % obj['postf']
447 | name = prefix_value + name
448 |
449 | if ext:
450 | name += '.' + ext
451 | if folder:
452 | name = os.path.join(folder, name)
453 | return name
454 |
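# Illustrative example (assumed values): keys are sorted alphabetically,
# values are abbreviated, and 'suf' becomes the prefix:
#
#   to_file_name({'suf': 'doom_bs', 'act': 'sigmoid', 'bs': 20, 'lr': 0.0004})
#   # -> 'doom_bs__act|sigmoid__bs|20__lr|0.0004'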
455 |
456 | def dict_to_ordereddict(dict):
457 | return collections.OrderedDict(sorted(dict.items()))
458 |
459 |
460 | def configure_folders_2(FLAGS, meta):
461 | folder_meta = meta.copy()
462 | folder_meta.pop('init')
463 | folder_meta.pop('lr')
464 | folder_meta.pop('opt')
465 | folder_meta.pop('bs')
466 | folder_name = to_file_name(folder_meta) + '/'
467 | checkpoint_folder = os.path.join(TEMP_FOLDER, folder_name)
468 | log_folder = os.path.join(checkpoint_folder, 'log')
469 | mkdir([TEMP_FOLDER, IMAGE_FOLDER, checkpoint_folder, log_folder])
470 | FLAGS.save_path = checkpoint_folder
471 | FLAGS.logdir = log_folder
472 | return checkpoint_folder, log_folder
473 |
474 |
475 | # precision/recall evaluation
476 |
477 | def evaluate_precision_recall(y, target, labels):
478 | import sklearn.metrics as metrics
479 | target = target[:len(y)]
480 | num_classes = max(target) + 1
481 | results = []
482 | for i in range(num_classes):
483 | class_target = _extract_single_class(i, target)
484 | class_y = _extract_single_class(i, y)
485 |
486 | results.append({
487 | 'precision': metrics.precision_score(class_target, class_y),
488 | 'recall': metrics.recall_score(class_target, class_y),
489 | 'f1': metrics.f1_score(class_target, class_y),
490 | 'fraction': sum(class_target)/len(target),
491 | '#of_class': int(sum(class_target)),
492 | 'label': labels[i],
493 | 'label_id': i
494 | # 'tp': tp
495 | })
496 | print('%d/%d' % (i, num_classes), results[-1])
497 | accuracy = metrics.accuracy_score(target, y)
498 | return accuracy, results
499 |
500 |
501 | def _extract_single_class(i, classes):
502 | res, i = classes + 1, i + 1
503 | res[res != i] = 0
504 | res = np.asarray(res)
505 | res = res / i
506 | return res
507 |
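# Worked example: labels are shifted by one so that class 0 survives the
# zeroing step, then rescaled back into a 0/1 indicator:
#
#   _extract_single_class(0, np.array([0, 2, 1, 0]))
#   # -> array([ 1.,  0.,  0.,  1.])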
508 |
509 |
510 | def print_relevance_info(relevance, prefix='', labels=None):
511 | labels = labels if labels is not None else np.arange(len(relevance))
512 | separator = '\n\t' if len(relevance) > 3 else ' '
513 | result = '%s format: [" label": f1_score (precision recall) label_percentage]' % prefix
514 | format = '\x1B[0m%s\t"%25s":\x1B[31;40m%.2f\x1B[0m (%.2f %.2f) %d%%'
515 | for i, label_relevance in enumerate(relevance):
516 | result += format % (separator,
517 | str(labels[i]),
518 | label_relevance['f1'],
519 | label_relevance['precision'],
520 | label_relevance['recall'],
521 | int(label_relevance['fraction']*10000)/100.
522 | )
523 | print(result)
524 |
525 |
526 | def disalbe_tensorflow_warnings():
527 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
528 |
529 |
530 | @timeit
531 | def images_to_sprite(arr, path=None):
532 | assert len(arr) <= 100*100
533 | arr = arr[...,:3]
534 | resized = [scipy.misc.imresize(x[..., :3], size=[80, 80]) for x in arr]
535 | base = np.zeros([8000, 8000, 3], np.uint8)
536 |
537 | for i in range(100):
538 | for j in range(100):
539 | index = j+100*i
540 | if index < len(resized):
541 | base[80*i:80*i+80, 80*j:80*j+80] = resized[index]
542 | scipy.misc.imsave(path, base)
543 |
544 |
545 | def generate_tsv(num, path):
546 | with open(path, mode='w') as f:
547 | [f.write('%d\n' % i) for i in range(num)]
548 |
549 |
550 | def paste_patch(patch, base_size=40, upper_half=True):
551 | channels = patch.shape[-1]
552 | base = np.zeros((base_size, base_size, channels), dtype=np.uint8)
553 |
554 | position_x = random.randint(0, base_size - patch.shape[0])
555 |   position_y = random.randint(0, base_size // 2 - patch.shape[1])
556 | if not upper_half: position_y += int(base_size / 2)
557 |
558 | base[position_x:position_x + patch.shape[0], position_y:position_y + patch.shape[1], :] = patch
559 | return base
560 |
561 |
562 | if __name__ == '__main__':
563 | data = []
564 | for i in range(10):
565 | data.append((str(i), np.random.rand(1000)))
566 |
567 |
568 | mask_busy_gpus()
--------------------------------------------------------------------------------
/video_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import autoencoder as ae
4 | import os
5 | import sys
6 | import utils as ut
7 | import visualization as vis
8 | import matplotlib.pyplot as plt
9 |
10 | FLAGS = tf.app.flags.FLAGS
11 |
12 |
13 | # 1. load model
14 | # 2. convert to static model
15 | # 3. fetch data
16 | # 4. feed one by one into the model -> create projection
17 |
18 | def restore_flags(path):
19 |   x = FLAGS.input_path  # touching a flag forces tf.app.flags to parse argv
20 | log_file = os.path.join(path, '!note.txt')
21 |
22 | with open(log_file, 'r') as f:
23 | lines = f.readlines()
24 |
25 | # collect flags
26 | flag_stored = {}
27 | for l in lines:
28 | if ': \t' in l:
29 | parts = l[:-1].split(': \t')
30 |       key = parts[0].strip()
32 | val = parts[1]
33 | flag_stored[key] = val
34 |
35 | print(flag_stored)
36 | print(FLAGS.__dict__['__flags'])
37 |
38 | flag_current = FLAGS.__dict__['__flags']
39 | for k in flag_stored.keys():
40 | if k in flag_current:
41 | # fix int issue '10.0' => '10'
42 | if type(flag_current[k]) == int:
43 | flag_stored[k] = flag_stored[k].split('.')[0]
44 |
45 | if type(flag_current[k]) == str:
46 | flag_stored[k] = flag_stored[k][1:-1]
47 | print(flag_stored[k], k)
48 |
49 | type_friendly_val = type(flag_current[k])(flag_stored[k])
50 | # print(k, type(dest[k]), flags[k], type(dest[k])(flags[k]))
51 | flag_current[k] = type_friendly_val
52 |
53 |
54 | # def print_video(ecn, img, reco):
55 |
56 |
57 |
58 |
59 | def data_to_img(img):
60 | h, w, c = img.shape
61 | base = np.zeros((h, 2*w, 3))
62 | base[:, :w, :] = img[:, :, :3]  # RGB channels on the left
63 | base[:, w:] = np.expand_dims(img[:, :, 3], axis=2)  # 4th channel broadcast to grey on the right
64 | return base
65 |
66 |
67 | def _show_image(img, original=True):
68 | # index = 322 if original else 326
69 | index = (0, 1) if original else (4, 1)
70 | # ax = plt.subplot(index)
71 | ax = plt.subplot2grid((7, 2), index, rowspan=3)
72 | ax.imshow(img)
73 | ax.set_title('Original' if original else 'Reconstruction')
74 | ax.axis('off')
75 |
76 |
77 | TAIL_LENGTH = 30
78 |
79 |
80 | def animate(cur, enc, img, reco):
81 | fig = vis.get_figure()
82 |
83 | original = data_to_img(img[cur])
84 | reconstr = data_to_img(reco[cur])
85 |
86 | _show_image(original)
87 | _show_image(reconstr, original=False)
88 |
89 | # animation
90 | ax = plt.subplot2grid((7, 2), (0, 0), rowspan=7, projection='3d')
91 | ax.set_title('Trajectory')
92 | ax.axes.get_xaxis().set_ticks([])
93 | ax.axes.get_yaxis().set_ticks([])
94 | ax.set_zticks([])
95 |
96 | if enc.shape[1] > 3:
97 | enc = enc[:, :3]  # only the first 3 latent dimensions are plotted
98 |
99 | # white
100 | data = enc[cur:]
101 | ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='w', s=1, zorder=15)
102 | # old
103 | tail_len = max(0, cur - TAIL_LENGTH)
104 | # print(tail_len)
105 | if tail_len > 0:
106 | data = enc[:tail_len]
107 | ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='black', zorder=10, s=1,)
108 | # recent
109 | for i in range(0, -TAIL_LENGTH, -1):
110 | j = i + cur
111 | if j >= 0:
112 | # print(cur, i)
113 | data = enc[j]
114 | ax.scatter(data[0], data[1], data[2], c='b', s=(i+TAIL_LENGTH)/5, zorder=5)
115 |
116 | plt.show()
117 |
118 | if __name__ == '__main__':
119 | path = os.getcwd()
120 | path = '/mnt/code/vd/TensorFlow_DCIGN/tmp/pred.16c3s2_32c3s2_32c3s2_16c3_f3__i_romb8.5.6'
121 |
122 | # write new flags: batch_size, path, save_path
123 | restore_flags(path)
124 | FLAGS.batch_size = 1
125 |
126 | path = FLAGS.input_path if len(FLAGS.test_path) == 0 else FLAGS.test_path
127 | path = '../../' + path
128 | path = sys.argv[-1] if len(sys.argv) == 3 else path
129 | FLAGS.test_path, FLAGS.input_path = path, path
130 | FLAGS.save_path = os.getcwd()
131 | FLAGS.model = 'ae'
132 | FLAGS.new_blur = False
133 | print(FLAGS.net, FLAGS.new_blur, FLAGS.test_path, FLAGS.input_path, os.getcwd())
134 |
135 | # run inference
136 | model = ae.Autoencoder(need_forlders=False)
137 | enc, reco = model.inference(max=100)
138 | img = model.test_set
139 |
140 | # enc, img, reco = np.arange(0, 364*3).reshape((364, 3)), np.random.rand(364, 80, 160, 4), np.random.rand(364, 80, 160, 4)
141 |
142 | for i in range(len(enc)):
143 | animate(i, enc, img, reco)
144 |
145 | # (364, 3)(364, 80, 160, 4)
146 |
147 |
--------------------------------------------------------------------------------
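
`restore_flags` above reaches into TF1's private `FLAGS.__dict__['__flags']` dict to coerce stored values back to their original types. The same idea can be sketched without TF internals; the `name: \tvalue` note-file format mirrors what the parser above expects, and `parse_note_flags` plus the `defaults` dict are our own illustrative names:

def parse_note_flags(lines, defaults):
  """Parse '<name>: \t<value>' lines, coercing values to the defaults' types."""
  restored = dict(defaults)
  for line in lines:
    if ': \t' not in line:
      continue
    key, val = line.rstrip('\n').split(': \t', 1)
    key = key.strip()
    if key not in defaults:
      continue
    if isinstance(defaults[key], bool):    # check bool before int: bool subclasses int
      restored[key] = val in ('True', 'true', '1')
    elif isinstance(defaults[key], int):
      restored[key] = int(float(val))      # handles the stored '10.0' => 10 case
    elif isinstance(defaults[key], str):
      restored[key] = val.strip('\'"')     # drop the quotes repr() added
    else:
      restored[key] = type(defaults[key])(val)
  return restored

# e.g. parse_note_flags(open('!note.txt'), {'batch_size': 30, 'net': ''})
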
/visualization.py:
--------------------------------------------------------------------------------
1 | from sklearn.manifold import TSNE
2 | import sklearn.manifold as mn
3 | import matplotlib.pyplot as plt
4 | import sklearn.metrics.pairwise as pw
5 | import scipy.spatial.distance as dist
6 | import numpy as np
7 | from mpl_toolkits.mplot3d import Axes3D
8 | import os
9 | import sys
10 | import utils as ut
11 | import time
12 | import tensorflow as tf
13 | import matplotlib.ticker as ticker
14 |
15 |
16 | # Next line to silence pyflakes. This import is needed.
17 | Axes3D
18 |
19 | # colors = ['grey', 'red', 'magenta']
20 |
21 | FLAGS = tf.app.flags.FLAGS
22 | COLOR_MAP = plt.cm.Spectral
23 | PICKER_SENSITIVITY = 5
24 |
25 |
26 | def scatter(plot, data, is3d, colors):
27 | if is3d:
28 | plot.scatter(data[:, 0], data[:, 1], data[:, 2],
29 | marker='.',
30 | c=colors,
31 | cmap=plt.cm.Spectral,
32 | picker=PICKER_SENSITIVITY)
33 | else:
34 | plot.scatter(data[:, 0], data[:, 1],
35 | c=colors,
36 | cmap=plt.cm.Spectral,
37 | picker=PICKER_SENSITIVITY)
38 |
39 |
40 | def print_data_only(data, file_name, fig=None, interactive=False):
41 | subplot_number = 121 if fig is not None else 111  # a caller-supplied figure reserves room for a second panel
42 | fig = get_figure(fig)
43 | fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 1)
44 |
45 | colors = _build_radial_colors(len(data))
46 | if data.shape[1] > 2:
47 | subplot = plt.subplot(subplot_number, projection='3d')
48 | subplot.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors,
49 | cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
50 | subplot.plot(data[:, 0], data[:, 1], data[:, 2])
51 | else:
52 | subplot = plt.subplot(subplot_number)
53 | subplot.scatter(data[:, 0], data[:, 1], c=colors,
54 | cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
55 | if not interactive:
56 | save_fig(file_name, fig)
57 |
58 |
59 | def create_gif_from_folder(folder):
60 | # fig = plt.figure()
61 | # gif_folder = os.path.join(FLAGS.save_path, 'gif')
62 | # if not os.path.exists(gif_folder):
63 | # os.mkdir(gif_folder)
64 | # epoch = file_name.split('_e|')[-1].split('_')[0]
65 | # gif_path = os.path.join(gif_folder, epoch)
66 | # subplot = plt.subplot(111, projection='3d')
67 | # subplot.scatter(data[0], data[1], data[2], c=colors, cmap=color_map)
68 | # save_fig(file_name)
69 | pass
70 |
71 |
72 | def manual_pca(data, std_threshold=0.01, num_threshold=3):
73 | """remove meaningless dimensions"""
74 | std = data[0:300].std(axis=0)
75 |
76 | order = np.argsort(std)[::-1]
77 | # order = np.arange(0, data.shape[1]).astype(np.int32)
78 | std = std[order]
79 | # filter components by STD but take at least 3
80 |
81 | meaningless = [order[i] for i, x in enumerate(std) if x <= std_threshold]
82 | if len(meaningless) > 0 and data.shape[1] > 3:  # any() would miss dimension index 0
83 | # ut.print_info('meaningless dimensions on visualization: %s' % str(meaningless))
84 | pass
85 |
86 | order = [order[i] for i, x in enumerate(std) if x > std_threshold or i < num_threshold]
87 | order.sort()
88 | return data[:, order]
89 |
90 |
91 | def _needs_hessian(manifold):
92 | if hasattr(manifold, 'dissimilarity') and manifold.dissimilarity == 'precomputed':
93 | return True
94 | if hasattr(manifold, 'metric') and manifold.metric == 'precomputed':
95 | return True
96 | return False
97 |
98 |
99 | @ut.images_to_uint8
100 | @ut.timeit
101 | def visualize_encoding(encodings, folder=None, meta=None, original=None, reconstruction=None):
102 | if original is not None and np.max(original) < 10:
103 | original = (original * 255).astype(np.uint8)
104 | # print('np', np.max(original), np.max(reconstruction), np.min(original), np.min(reconstruction),
105 | # original.dtype, reconstruction.dtype)
106 | file_path = None
107 | if folder:
108 | meta = dict(meta or {}, postfix='pca')  # copy instead of mutating the caller's dict
109 | file_path = ut.to_file_name(meta, folder, 'jpg')
110 | encodings = manual_pca(encodings)
111 |
112 | if original is not None:
113 | assert len(original) == len(reconstruction)
114 | fig = get_figure()
115 |
116 | # print('reco max:', np.max(reconstruction))
117 | column_picture, height = _stitch_images(original, reconstruction)
118 | subplot, proportion = (122, 1) if encodings.shape[1] <= 3 else (155, 3)
119 | picture = _reshape_column_image(column_picture, height, proportion=proportion)
120 | if picture.shape[-1] == 1:
121 | picture = picture.squeeze()
122 | plt.subplot(subplot).set_title("Original/reconstruction")
123 | plt.subplot(subplot).imshow(picture)
124 | plt.subplot(subplot).axis('off')
125 |
126 | visualize_encodings(encodings, file_name=file_path, fig=fig, grid=(3, 5), skip_every=5)
127 | else:
128 | visualize_encodings(encodings, file_name=file_path)
129 |
130 |
131 | # @ut.timeit
132 | def plot_encoding_crosssection(encodings, file_path, original=None, reconstruction=None, interactive=False):
133 | # print(encodings.shape)
134 | # print(original.shape)
135 | # print(reconstruction.shape)
136 | # encodings = manual_pca(encodings)
137 |
138 | fig = get_figure()
139 | if original is not None:
140 | assert len(original) == len(reconstruction), (len(original), len(reconstruction))
141 | subplot, proportion = visualize_cross_section_with_reco(encodings)
142 | picture = stitch_side_by_side(original, reconstruction, proportion)
143 | subplot.imshow(picture)
144 | else:
145 | visualize_cross_section(encodings)
146 | if not interactive:
147 | save_fig(file_path, fig)
148 | return fig
149 |
150 |
151 | def plot_reconstruction(original, reconstruction, meta={'debug': 'true'}, interactive=False):
152 | # if not interactive:
153 | # _get_figure()
154 | picture = stitch_side_by_side(original, reconstruction)
155 | plt.imshow(picture)
156 | if not interactive:
157 | file_path = ut.to_file_name(meta, FLAGS.save_path, 'jpg')
158 | save_fig(file_path)
159 | else:
160 | plt.draw()
161 | plt.pause(0.001)
162 |
163 |
164 | # Image stitching
165 |
166 |
167 | @ut.images_to_uint8
168 | def stitch_side_by_side(original, reconstruction, proportion=1):
169 | """
170 | Stitch 2 lists of images together for convenient display in a single
171 | rectangular shape of given side proportion
172 | """
173 | # print(np.max(original), original.dtype)
174 | if np.max(original) < 10:
175 | print("unexpected dynamic range of the original pictures", np.max(original), original.dtype)
176 | column_picture, height = _stitch_images(original, reconstruction)
177 | picture = _reshape_column_image(column_picture, height, proportion=proportion)
178 | if picture.shape[-1] == 1:
179 | picture = picture.squeeze()
180 | return picture
181 |
182 |
183 | def _stitch_images(*args):
184 | """Recieves one or many arrays of pictures and stitches them alongside into a column picture"""
185 | assert len(args) == 2
186 | lines, height, width, channels = args[0].shape
187 | pad_value = 0
188 | stack = args[0]
189 | # print([(type(x), x.shape) for x in args])
190 | # print(np.min(args[0]), np.max(args[0]), args[0].dtype, args[0].mean())
191 | # print(np.min(args[1]), np.max(args[1]), args[1].dtype, args[1].mean())
192 | if len(args) > 1:
193 | for i in range(len(args) - 1):
194 | stack = np.concatenate((stack, args[i + 1]), axis=2)
195 | # stack - array of lines of pictures (arr_0[0], arr_1[0], ...)
196 | # concatenate lines in one picture (height = tile_h * #lines)
197 | picture_lines = stack.reshape(lines * height, stack.shape[2], channels)
198 | picture_lines = np.hstack((
199 | picture_lines,
200 | np.ones((lines * height, 2, channels), dtype=np.uint8) * pad_value)) # pad 2 pixels
201 |
202 | # slice/reshape to have better image proportions
203 | return picture_lines, height
204 |
205 |
206 | def _reshape_column_image(column_picture, height, proportion=1):
207 | """
208 | Take a "column" image of independent horizontal segments of 'height'
209 | and reshape it to fit into given proportion
210 |
211 | |img1|
212 | |img2| |img1 img4|
213 | |img3| => |img2 img5|
214 | |img4| |img3 |
215 | |img5|
216 |
217 | """
218 | lines = int(column_picture.shape[0] / height)
219 | width = column_picture.shape[1]
220 |
221 | column_size = int(np.ceil(np.sqrt(lines * proportion * (width / height))))
222 |
223 | # avoid an almost empty last column, e.g. one that would contain only 1 or 2 pictures
224 | if lines % column_size <= lines / column_size:
225 | column_size += 1
226 |
227 | count = int(column_picture.shape[0] / height)
228 | _, _, channels = column_picture.shape
229 |
230 | picture = column_picture[0:column_size * height, :, :]
231 |
232 | for i in range(int(lines / column_size)):
233 | start, stop = column_size * height * (i + 1), column_size * height * (i + 2)
234 | if start >= len(column_picture):
235 | break
236 | if stop < len(column_picture):
237 | picture = np.hstack((picture, column_picture[start:stop]))
238 | else:
239 | last_column = np.vstack((
240 | column_picture[start:],
241 | np.ones((stop - len(column_picture), column_picture.shape[1], column_picture.shape[2]),
242 | dtype=np.uint8)))
243 | picture = np.hstack((picture, last_column))
244 |
245 | return picture
246 |
247 |
248 | figure_shape = [3095, 2352, 3]
249 |
250 |
251 | def get_figure(fig=None, shape=figure_shape):
252 | if fig is not None:
253 | return fig
254 | dpi = 300.
255 | fig = plt.figure(num=0, figsize=[shape[0]/dpi, shape[1]/dpi], dpi=dpi)
256 | fig.clf()
257 | return fig
258 |
259 |
260 | def _plot_single_cross_section(data, select, subplot):
261 | data = data[:, select]
262 | # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0,
263 | subplot.plot(data[:, 0], data[:, 1], color='black', lw=1, alpha=0.4)
264 | subplot.plot(data[[-1, 0], 0], data[[-1, 0], 1], lw=1, alpha=0.8, color='red')
265 | subplot.scatter(data[:, 0], data[:, 1], s=4, alpha=1.0, lw=0.5,
266 | c=_build_radial_colors(len(data)),
267 | marker=".",
268 | cmap=plt.cm.Spectral)
269 | # data = np.vstack((data, np.asarray([data[0, :]])))
270 | # subplot.plot(data[:, 0], data[:, 1], alpha=0.4)
271 |
272 | subplot.set_xlabel('feature %d' % select[0], labelpad=-12)
273 | subplot.set_ylabel('feature %d' % select[1], labelpad=-12)
274 | subplot.set_xlim([-0.05, 1.05])
275 | subplot.set_ylim([-0.05, 1.05])
276 | subplot.xaxis.set_ticks([0, 1])
277 | subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
278 | subplot.yaxis.set_ticks([0, 1])
279 | subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
280 |
281 |
282 | def _plot_single_cross_section_3d(data, select, subplot):
283 | data = data[:, select]
284 | # subplot.scatter(data[:, 0], data[:, 1], s=20, lw=0, edgecolors='none', alpha=1.0,
285 | subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=1, alpha=0.4)
286 | subplot.plot(data[[-1, 0], 0], data[[-1, 0], 1], data[[-1, 0], 2], lw=1, alpha=0.8, color='red')
287 | subplot.scatter(data[:, 0], data[:, 1], data[:, 2], s=4, alpha=1.0, lw=0.5,
288 | c=_build_radial_colors(len(data)),
289 | marker=".",
290 | cmap=plt.cm.Spectral)
291 | data = data[0::10]
292 | # subplot.plot(data[:, 0], data[:, 1], data[:, 2], color='black', lw=2, alpha=0.8)
293 |
294 | # data = np.vstack((data, np.asarray([data[0, :]])))
295 | # subplot.plot(data[:, 0], data[:, 1], alpha=0.4)
296 |
297 | subplot.set_xlabel('feature %d' % select[0], labelpad=-12)
298 | subplot.set_ylabel('feature %d' % select[1], labelpad=-12)
299 | subplot.set_zlabel('feature %d' % select[2], labelpad=-12)
300 | subplot.set_xlim([-0.01, 1.01])
301 | subplot.set_ylim([-0.01, 1.01])
302 | subplot.set_zlim([-0.01, 1.01])
303 | subplot.xaxis.set_ticks([0, 1])
304 | subplot.yaxis.set_ticks([0, 1])
305 | subplot.zaxis.set_ticks([0, 1])
306 | subplot.xaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
307 | subplot.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.0f'))
308 |
309 |
310 | @ut.deprecated
311 | def visualize_cross_section(embeddings):
312 | features = embeddings.shape[-1]
313 | size = features - 1
314 | for i in range(features):
315 | for j in range(i + 1, features):
316 | pos = i * size + j
317 | subplot = plt.subplot(size, size, pos)
318 | _plot_single_cross_section(embeddings, [i, j], subplot)
319 |
320 | if features >= 3:
321 | pos = (size + 1) * size - size + 1
322 | subplot = plt.subplot(size + 1, size, pos)
323 | _plot_single_cross_section(embeddings, [0, 1], subplot)
324 | return size
325 |
326 |
327 | def visualize_cross_section_with_reco(embeddings):
328 | features = embeddings.shape[-1]
329 | size = features - 1
330 | for i in range(features):
331 | for j in range(i + 1, features):
332 | pos = i * (size + 1) + j
333 | subplot = plt.subplot(size, size + 1, pos)
334 | _plot_single_cross_section(embeddings, [i, j], subplot)
335 | reco_subplot = plt.subplot(1, size + 1, size + 1)
336 | reco_subplot.axis('off')
337 |
338 | if size >= 2:
339 | single_size = size if size < 4 else int(size/2)
340 | pos = single_size * (single_size + 1) - (single_size + 1) + 1
341 | subplot = plt.subplot(single_size, single_size+1, pos, projection='3d')
342 | _plot_single_cross_section_3d(embeddings, [0, 1, 2], subplot)
343 | return reco_subplot, size
344 |
345 |
346 | # cross section end
347 |
348 |
349 | def visualize_encodings(encodings, file_name=None,
350 | grid=None, skip_every=999, fast=False, fig=None, interactive=False):
351 | encodings = manual_pca(encodings)
352 | if encodings.shape[1] <= 3:
353 | return print_data_only(encodings, file_name, fig=fig, interactive=interactive)
354 |
355 | encodings = encodings[0:720]
356 | hessian_euc = dist.squareform(dist.pdist(encodings, 'euclidean'))
357 | hessian_cos = dist.squareform(dist.pdist(encodings, 'cosine'))
358 | grid = (3, 4) if grid is None else grid
359 | project_ops = []
360 |
361 | n = 2
362 | project_ops.append(("LLE ltsa N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='ltsa')))
363 | project_ops.append(("LLE modified N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='modified')))
364 | project_ops.append(('MDS euclidean N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
365 | project_ops.append(("TSNE 30/2000 N:%d" % n, TSNE(perplexity=30, n_components=n, init='pca', n_iter=2000)))
366 | n = 3
367 | project_ops.append(("LLE ltsa N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='ltsa')))
368 | project_ops.append(("LLE modified N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='modified')))
369 | project_ops.append(('MDS euclidean N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
370 | project_ops.append(('MDS cosine N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
371 |
372 | plot_places = []
373 | for i in range(12):  # skip every skip_every-th grid cell; those hold the raw-data plots below
374 | u, v = int(i / (skip_every - 1)), i % (skip_every - 1)
375 | j = v + u * skip_every + 1
376 | plot_places.append(j)
377 |
378 | fig = get_figure(fig)
379 | fig.set_size_inches(fig.get_size_inches()[0] * grid[0] / 1.,
380 | fig.get_size_inches()[1] * grid[1] / 2.0)
381 |
382 | for i, (name, manifold) in enumerate(project_ops):
383 | is3d = 'N:3' in name
384 |
385 | try:
386 | if is3d:
387 | subplot = plt.subplot(grid[0], grid[1], plot_places[i], projection='3d')
388 | else:
389 | subplot = plt.subplot(grid[0], grid[1], plot_places[i])
390 |
391 | data_source = encodings if not _needs_hessian(manifold) else \
392 | (hessian_cos if 'cosine' in name else hessian_euc)
393 | projections = manifold.fit_transform(data_source)
394 | scatter(subplot, projections, is3d, _build_radial_colors(len(data_source)))
395 | subplot.set_title(name)
396 | except Exception:
397 | print(name, "Unexpected error: ", sys.exc_info()[0], sys.exc_info()[1] if len(sys.exc_info()) > 1 else '')
398 |
399 | visualize_data_same(encodings, grid=grid, places=plot_places[-4:])
400 | if not interactive:
401 | save_fig(file_name, fig)
402 | ut.print_time('visualization finished')
403 |
404 |
405 | def save_fig(file_name, fig=None):
406 | if not file_name:
407 | plt.show()
408 | else:
409 | plt.savefig(file_name, dpi=300, facecolor='w', edgecolor='w',
410 | transparent=False, bbox_inches='tight', pad_inches=0.1,
411 | frameon=None)
412 | # plt.close('all')
413 |
414 |
415 | def _random_split(sequence, length, original):
416 | if sequence is None or len(sequence) < length:
417 | sequence = original.copy()
418 | sequence = np.random.permutation(sequence)
419 | return sequence[:length], sequence[length:]
420 |
421 |
422 | @ut.deprecated
423 | def visualize_data_same(data, grid, places):
424 | assert len(places) == 4
425 |
426 | all_dimensions = np.arange(0, data.shape[1]).astype(np.int8)
427 | first_proj, left = _random_split(None, 2, all_dimensions)
428 | first_color_indexes, _ = _random_split(left, 3, all_dimensions)
429 | first_color = _data_to_colors(data, first_color_indexes - 2)
430 |
431 | second_proj, left = _random_split(left, 2, all_dimensions)
432 | second_color = _build_radial_colors(len(data))
433 |
434 | third_proj, left = _random_split(left, 3, all_dimensions)
435 | third_color_indexes, _ = _random_split(left, 3, all_dimensions)
436 | third_color = _data_to_colors(data, third_color_indexes)
437 |
438 | forth_proj = np.argsort(data.std(axis=0))[::-1][0:3]
439 | forth_color = _build_radial_colors(len(data))
440 |
441 | for i, (projection, color) in enumerate([
442 | (first_proj, first_color),
443 | (second_proj, second_color),
444 | (third_proj, third_color),
445 | (forth_proj, forth_color)]
446 | ):
447 | points = np.transpose(data[:, projection])
448 |
449 | if len(projection) == 2:
450 | subplot = plt.subplot(grid[0], grid[1], places[i])
451 | subplot.scatter(points[0], points[1], c=color, cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
452 | else:
453 | subplot = plt.subplot(grid[0], grid[1], places[i], projection='3d')
454 | subplot.scatter(points[0], points[1], points[2], c=color, cmap=COLOR_MAP, picker=PICKER_SENSITIVITY)
455 | subplot.set_title('Data %s %s' % (str(projection), 'sequential color' if i % 2 == 1 else ''))
456 |
457 |
458 | @ut.deprecated
459 | def visualize_data_same_deprecated(data, grid, places, dims_as_colors=False):
460 | assert len(places) == 4
461 | dimensions = np.arange(0, np.min([6, data.shape[1]])).astype(np.int)
462 | assert len(dimensions) == data.shape[1] or len(dimensions) == 6
463 | projections = [dimensions[x] for x in [[0, 1], [-1, -2], [0, 1, 2], [-1, -2, -3]]]
464 | colors = _build_radial_colors(len(data))
465 |
466 | for i, dims in enumerate(projections):
467 | points = np.transpose(data[:, dims])
468 | if dims_as_colors:
469 | colors = _data_to_colors(np.delete(data.copy(), dims, axis=1))
470 |
471 | if len(dims) == 2:
472 | subplot = plt.subplot(grid[0], grid[1], places[i])
473 | subplot.scatter(points[0], points[1], c=colors, cmap=COLOR_MAP)
474 | else:
475 | subplot = plt.subplot(grid[0], grid[1], places[i], projection='3d')
476 | subplot.scatter(points[0], points[1], points[2], c=colors, cmap=COLOR_MAP)
477 | subplot.set_title('Data %s' % str(dims))
478 |
479 |
480 | def _duplicate_array(array, repeats=None, total_length=None):  # tiles the array: ABC -> ABCABC
481 | assert repeats is not None or total_length is not None
482 |
483 | if repeats is None:
484 | repeats = int(np.ceil(total_length / len(array)))
485 | res = array.copy()
486 | for i in range(repeats - 1):
487 | res = np.concatenate((res, array))
488 | return res if total_length is None else res[:total_length]
489 |
490 |
491 | def _duplicate_array_repeat(array, repeats=None, total_length=None):  # repeats elementwise: ABC -> AABBCC
492 | assert repeats is not None or total_length is not None
493 |
494 | if repeats is None:
495 | repeats = int(np.ceil(total_length / len(array)))
496 | parts = [array for i in range(repeats)]
497 | whole = np.stack(parts, axis=1)
498 | whole = whole.reshape(np.prod(whole.shape))
499 | return whole if total_length is None else whole[:total_length]
500 |
501 |
502 | def _build_radial_colors(length):
503 | colors = np.arange(0, 180)
504 | # colors = np.concatenate((colors, colors[::-1]))
505 | colors = _duplicate_array_repeat(colors, total_length=length)
506 | return colors
507 |
508 |
509 | def _data_to_colors(data, indexes=None):
510 | color_data = data[:, indexes] if indexes is not None else data
511 | shape = color_data.shape
512 |
513 | if shape[1] < 3:
514 | add = 3 - shape[1]
515 | add = np.ones((shape[0], add)) * 0.5
516 | color_data = np.concatenate((color_data, add), axis=1)
517 | elif shape[1] > 3:
518 | color_data = color_data[:, 0:3]
519 |
520 | if np.max(color_data) <= 1:
521 | color_data *= 256
522 | color_data = color_data.astype(np.int32)
523 | assert np.mean(color_data) <= 256
524 | color_data[color_data > 255] = 255
525 | color_data *= np.asarray([256 ** 2, 256, 1])
526 |
527 | color_data = np.sum(color_data, axis=1)
528 | color_data = ["#%06x" % c for c in color_data]
529 | return color_data
530 |
531 |
532 | def visualize_available_data(root='./', reembed=True, with_mds=False):
533 | tf.app.flags.DEFINE_string('suffix', 'grid', '')
534 | FLAGS.suffix = '__' if with_mds else 'grid'
535 | assert os.path.exists(root)
536 |
537 | files = _list_embedding_files(root, reembed=reembed)
538 | total_files = len(files)
539 |
540 | for i, file in enumerate(files):
541 | print('%d/%d %s' % (i + 1, total_files, file))
542 | data = np.loadtxt(file)
543 | png_name = file.replace('.txt', '_pca.png')
544 |
545 | visualize_encodings(data, file_name=png_name)
546 |
547 |
548 | def _list_embedding_files(root, reembed=False):
549 | encoding_files = []
550 | for cur_root, dirs, files in os.walk(root):  # don't shadow the function argument
551 | if '/tmp' in cur_root:
552 | for file in files:
553 | if '.txt' in file and 'meta' not in file:
554 | full_path = os.path.join(cur_root, file)
555 | if not reembed:
556 | vis_file = full_path.replace('.txt', '_pca.png')
557 | if os.path.exists(vis_file):
558 | continue
559 | encoding_files.append(full_path)
560 | return encoding_files
561 |
562 |
563 | if __name__ == '__main__':
564 | path = '../../encodings__e|500__z_ac|96.5751.txt'
565 | x = np.loadtxt(path)
566 | # x = np.random.rand(100, 5)
567 | x = manual_pca(x)
568 | x = x[:360]
569 | visualize_cross_section_with_reco(x)
570 | plt.tight_layout()
571 | plt.show()
572 |
--------------------------------------------------------------------------------
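
Despite its name, `manual_pca` above performs no PCA: it simply drops latent dimensions whose standard deviation (estimated on the first 300 rows) falls below a threshold, while always keeping at least `num_threshold` of them. A self-contained sketch of that filtering rule (`filter_low_variance_dims` is our illustrative name):

import numpy as np

def filter_low_variance_dims(data, std_threshold=0.01, num_threshold=3):
  std = data[:300].std(axis=0)    # estimate spread on a prefix, like manual_pca
  order = np.argsort(std)[::-1]   # dimension indices, widest spread first
  keep = [d for i, d in enumerate(order)
          if std[d] > std_threshold or i < num_threshold]
  keep.sort()                     # restore the original dimension order
  return data[:, keep]

# usage: a 5-D cloud where only the first 3 dimensions vary meaningfully
data = np.random.rand(500, 5) * np.array([1, 1, 1, 1e-4, 1e-4])
print(filter_low_variance_dims(data).shape)  # (500, 3)
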
/visualize_latest.py:
--------------------------------------------------------------------------------
1 | import visualization as vi
2 | import utils as ut
3 | import input as inp
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import tensorflow as tf
7 | import os
8 | import sys
9 | import visualization as vis
10 |
11 | from mpl_toolkits.mplot3d import Axes3D
12 |
13 | # Next line to silence pyflakes. This import is needed.
14 | Axes3D
15 |
16 | FLAGS = tf.app.flags.FLAGS
17 |
18 |
19 | def print_data(data, fig, subplot, is_3d=True):
20 | colors = np.arange(0, 180)
21 | colors = np.concatenate((colors, colors[::-1]))
22 | colors = vi._duplicate_array(colors, total_length=len(data))
23 |
24 | if is_3d:
25 | subplot = fig.add_subplot(subplot, projection='3d')
26 | subplot.set_title('All data')
27 | subplot.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors, cmap=plt.cm.Spectral, picker=5)
28 | else:
29 | subsample = data[0:360] if len(data) < 2000 else data[0:720]
30 | subsample = np.concatenate((subsample, subsample))[0:len(subsample)+1]
31 | ut.print_info('subsample shape %s' % str(subsample.shape))
32 | subsample_colors = colors[0:len(subsample)]
33 | subplot = fig.add_subplot(subplot)
34 | subplot.set_title('First 360 elem')
35 | subplot.plot(subsample[:, 0], subsample[:, 1], picker=0)
36 | subplot.plot(subsample[0, 0], subsample[0, 1], picker=0)
37 | subplot.scatter(subsample[:, 0], subsample[:, 1], s=50, c=subsample_colors,
38 | cmap=plt.cm.Spectral, picker=5)
39 | return subplot
40 |
41 |
42 | class EncodingVisualizer:
43 | def __init__(self, fig, data):
44 | self.data = data
45 | self.fig = fig
46 | vi.visualize_encodings(data, grid=(3, 5), skip_every=5, fast=fast, fig=fig, interactive=True)
47 | plt.subplot(155).set_title('hold on')
48 | # fig.canvas.mpl_connect('button_press_event', self.on_click)
49 | fig.canvas.mpl_connect('pick_event', self.on_pick)
50 | try:
51 | # if True:
52 | ut.print_info('Checkpoint: %s' % FLAGS.load_from_checkpoint)
53 | self.model = dm.DoomModel()  # 'dm' is never imported here, so this falls through to the except branch
54 | self.reconstructions = self.model.decode(data)
55 | except Exception:
56 | ut.print_info("Model could not load from checkpoint %s" % str(sys.exc_info()), color=31)
57 | self.original_data, _ = inp.get_images(FLAGS.input_path)
58 | self.reconstructions = np.zeros(self.original_data.shape).astype(np.uint8)
59 | ut.print_info('INPUT: %s' % FLAGS.input_path.split('/')[-3])
60 | self.original_data, _ = inp.get_images(FLAGS.input_path)
61 |
62 |
63 | def on_pick(self, event):
64 | print(event)
65 | ind = event.ind
66 | print(ind)
67 | print(any([x for x in ind if x < 20]))
68 | orig = self.original_data[ind]
69 | reco = self.reconstructions[ind]
70 | column_picture, height = vi._stitch_images(orig, reco)
71 | picture = vi._reshape_column_image(column_picture, height, proportion=3)
72 |
73 | title = ''
74 | for i in range(len(ind)):
75 | title += ' ' + str(ind[i])
76 | if (i+1) % 8 == 0:
77 | title += '\n'
78 | plt.subplot(155).set_title(title)
79 | plt.subplot(155).imshow(picture)
80 | plt.show()
81 |
82 | def on_click(self, event):
83 | print('click', event)
84 |
85 |
86 | def visualize_latest_from_visualization_folder(folder='./visualizations/', file=None):
87 | if file is None:
88 | file = ut.get_latest_file(folder, filter=r'.*\d+\.txt$')
89 | ut.print_info('Encoding file: %s' % file.split('/')[-1])
90 | data = np.loadtxt(file) # [0:360]
91 | fig = plt.figure()
92 | vi.visualize_encodings(data, fast=fast, fig=fig, interactive=True)
93 | fig.suptitle(file.split('/')[-1])
94 | fig.tight_layout()
95 | plt.show()
96 |
97 |
98 | def visualize_from_checkpoint(checkpoint, epoch=None):
99 | assert os.path.exists(checkpoint)
100 | FLAGS.load_from_checkpoint = checkpoint
101 | file_filter = r'.*\d+\.txt$' if epoch is None else r'.*e\|%d.*' % epoch
102 | latest_file = ut.get_latest_file(folder=checkpoint, filter=file_filter)
103 | print(latest_file)
104 | ut.print_info('Encoding file: %s' % latest_file.split('/')[-1])
105 | data = np.loadtxt(latest_file)
106 | fig = plt.figure()
107 | fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 2)
108 | entity = EncodingVisualizer(fig, data)
109 | # fig.tight_layout()
110 | plt.show()
111 |
112 |
113 | fast = True
114 |
115 | if __name__ == '__main__':
116 |
117 | cwd = os.getcwd()
118 | # cwd = '/mnt/code/vd/TensorFlow_DCIGN/tmp/pred.16c3s2_32c3s2_32c3s2_16c3_f80_f8__i_grid.28c.4'
119 | latest = ut.get_latest_file(cwd, filter=r'.*_suf\.encodings\.npy$')
120 | print(latest)
121 | data = np.load(latest).item()
122 | # print(type(data))
123 | # i = data.item()
124 | # print(type(i))
125 | # print(i.shape)
126 | # print(data['enc'])
127 |
128 | # print(data)
129 | x = data['enc']
130 |
131 | # print(x)
132 |
133 | fig = vis.plot_encoding_crosssection(
134 | x,
135 | '',
136 | data['blu'],
137 | data['rec'],
138 | interactive=True)
139 | fig.set_size_inches(fig.get_size_inches()[0] * 2, fig.get_size_inches()[1] * 2)
140 | # plt.tight_layout()
141 | plt.show()
142 |
143 | # path = sys.argv[1] if len(sys.argv) > 1 \
144 | # else './tmp/ml__act|sigmoid__bs|30__h|500|10|500__init|na__inp|8pd3__lr|0.00003__opt|AO__seq|03'
145 | # epoch = int(sys.argv[2]) if len(sys.argv) > 2 else None
146 | #
147 | # # path = './tmp/doom_bs__act|sigmoid__bs|30__h|500|12|500__init|na__inp|8pd3__lr|0.0004__opt|AO/'
148 | #
149 | # # import os
150 | # # print('really? ', )
151 | #
152 | # if path is None:
153 | # ut.print_info('Visualizing latest file from visualization folder')
154 | # visualize_latest_from_visualization_folder()
155 | # exit(0)
156 | #
157 | # is_embedding = '.txt' in path
158 | # if is_embedding:
159 | # ut.print_info('Visualizing encoding file')
160 | # visualize_latest_from_visualization_folder(file=path)
161 | # exit(0)
162 | #
163 | # is_checkpoint = '/tmp' in path
164 | # if is_checkpoint:
165 | # print('so', path)
166 | # ut.print_info('Visualizing checkpoint data')
167 | # visualize_from_checkpoint(checkpoint=path, epoch=epoch)
168 | # else:
169 | # ut.print_info('Visualizing latest from folder', color=34)
170 | # visualize_latest_from_visualization_folder(folder=path)
171 |
--------------------------------------------------------------------------------
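
`EncodingVisualizer` above drives its interactivity through matplotlib's pick events: the scatter plots are created with `picker=PICKER_SENSITIVITY`, and the registered handler receives the indices of the clicked points in `event.ind`. A minimal self-contained sketch of that interaction pattern, with random data standing in for the encodings:

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(100, 2)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], picker=5)  # points within 5pt of a click fire a pick event

def on_pick(event):
  print('picked point indices:', list(event.ind))  # the same field on_pick above reads

fig.canvas.mpl_connect('pick_event', on_pick)
plt.show()
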